
[xpu] multi_encoder_xpu support smooth quant, skip quant and local quant #61617


Merged
merged 1 commit on Feb 19, 2024
349 changes: 267 additions & 82 deletions paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc

Large diffs are not rendered by default.

11 changes: 8 additions & 3 deletions paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h
@@ -128,6 +128,7 @@ struct PatternParam {
bool norm_before;
bool with_q_scale;
bool with_mask;
bool is_smooth_quant;
};

class MultiEncoderXPUFusePass : public FusePassBase {
@@ -142,7 +143,8 @@ class MultiEncoderXPUFusePass : public FusePassBase {
const std::string& matmul_type_2,
bool norm_before,
bool with_q_scale,
bool with_mask) const;
bool with_mask,
bool is_smooth_quant) const;

bool ApplyMultiEncoderXPUFuse(ir::Graph* graph) const;

@@ -152,7 +154,7 @@ class MultiEncoderXPUFusePass : public FusePassBase {
// 1. Transpose q_w, k_w, v_w
// 2. Concat q_w, k_w, v_w
// 3. Generate qkv_w_max tensor
// 4. Quant qkv_w to int16
// 4. Quant qkv_w to int16/int8 or cast to float16 (local quant)
void PrepareQKVWeight(
Graph* graph,
Scope* scope,
@@ -161,6 +163,7 @@
Node* k_w,
Node* v_w,
bool enable_int8,
bool local_quant,
std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
Node** qkv_w,
Node** qkv_w_max,
@@ -171,7 +174,9 @@
BlockDesc* block,
std::unordered_map<std::string, std::vector<Node*>>* node_maps,
std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
std::vector<Node*>* input_max_nodes) const;
std::vector<Node*>* input_max_nodes,
std::vector<std::string>* quant_types,
const std::string* act_type) const;

// 1. Cast bias to fp32
// 2. Concat q/k/v bias
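The numbered comments above summarize what PrepareQKVWeight does to the attention weights. As a rough illustration only, the following standalone sketch walks through the same four steps on plain std::vector<float> data; the helper names, the flattened layout, and the symmetric int16 formula are assumptions for illustration, not the pass implementation (which operates on phi::DenseTensor and also supports int8 and the fp16 "local quant" path).

// Illustrative sketch of the four PrepareQKVWeight steps listed above,
// using plain vectors instead of phi::DenseTensor. Not the pass code.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

struct QkvWeightSketch {
  std::vector<int16_t> data;  // concatenated q/k/v weight, int16-quantized
  float max_abs;              // scale stored in the qkv_w_max tensor
};

QkvWeightSketch PrepareQkvWeightSketch(const std::vector<float>& q_w,
                                       const std::vector<float>& k_w,
                                       const std::vector<float>& v_w) {
  // 1. (Transpose of each weight omitted here.)
  // 2. Concat q_w, k_w, v_w into one buffer.
  std::vector<float> qkv(q_w);
  qkv.insert(qkv.end(), k_w.begin(), k_w.end());
  qkv.insert(qkv.end(), v_w.begin(), v_w.end());
  // 3. Generate the max tensor: max |w| over the concatenated weight.
  float max_abs = 0.0f;
  for (float v : qkv) max_abs = std::max(max_abs, std::abs(v));
  if (max_abs == 0.0f) max_abs = 1.0f;  // guard against an all-zero weight
  // 4. Symmetric quantization to int16 (the int8 and fp16 local-quant
  //    variants follow the same structure with a different target type).
  QkvWeightSketch out;
  out.max_abs = max_abs;
  out.data.reserve(qkv.size());
  for (float v : qkv) {
    out.data.push_back(
        static_cast<int16_t>(std::lround(v / max_abs * 32767.0f)));
  }
  return out;
}
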
70 changes: 70 additions & 0 deletions paddle/fluid/framework/ir/xpu/pass_utils.cc
@@ -14,6 +14,7 @@

#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"

namespace paddle {
@@ -123,11 +124,68 @@ template size_t HashTensor<int16_t>(const phi::DenseTensor& in);
template size_t HashTensor<float>(const phi::DenseTensor& in);
template size_t HashTensor<int8_t>(const phi::DenseTensor& in);

template <>
size_t HashTensor<float16>(const phi::DenseTensor& in) {
phi::DenseTensor dst_tensor;
auto* cpu_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
dst_tensor.Resize(in.dims());
dst_tensor.set_type(phi::DataType::FLOAT32);
dst_tensor.set_layout(in.layout());
phi::CastKernel<float16>(*cpu_ctx, in, phi::DataType::FLOAT32, &dst_tensor);
return HashTensor<float>(dst_tensor);
}

std::string GetPrefixWithoutHash(const std::string& name) {
std::size_t found = name.find("_#");
return found == std::string::npos ? name : name.substr(0, found);
}

void ConvertFromFp32ToFp16(phi::DenseTensor* weight,
phi::DenseTensor* weight_max,
bool transpose) {
// Cast the weight to fp32 first (the input may be fp16 or fp32)
phi::DenseTensor weight_fp32;
CastToFp32(weight, &weight_fp32);

if (transpose) { // (k, n) -> (n, k)
Transpose2D(&weight_fp32);
}

auto FindMaxAbs = [](const float* data, int len) {
float max_f = 0.0f;
for (int i = 0; i < len; ++i) {
float max = std::abs(data[i]);
if (max > max_f) {
max_f = max;
}
}
return max_f;
};

auto* cpu_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
// Convert to fp16
phi::DenseTensor weight_fp16;
CastToFp16(&weight_fp32, &weight_fp16);
// Find max
int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1);
int size = weight_fp32.numel();
float max_val = FindMaxAbs(weight_fp32.data<float>(), size);
std::vector<float> max_vec(max_ptr_size, max_val);
weight_max->set_type(phi::DataType::FLOAT32);
weight_max->Resize({max_ptr_size});
memcpy(cpu_ctx->Alloc<float>(weight_max),
max_vec.data(),
max_ptr_size * sizeof(float));
weight->clear();
weight->set_type(phi::DataType::FLOAT16);
weight->Resize({size});
memcpy(cpu_ctx->Alloc<float16>(weight),
weight_fp16.data<float16>(),
size * sizeof(float16));
}
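
Note, as read from the function above (not from separate documentation): after ConvertFromFp32ToFp16 runs, weight holds a flat FLOAT16 tensor with the original element count, and weight_max holds an fp32 tensor of length get_xpu_max_ptr_size(-1) filled with the single max-abs value. This fp16-weight-plus-fp32-max pair appears to be the "local quant" path referenced in the fuse-pass header comment.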

template <typename Tcpu, typename Txpu>
void PrepareWeight(Graph* graph,
Scope* scope,
@@ -268,6 +326,18 @@ template void PrepareWeight<float, float>(
const std::vector<float>& weight_scales,
bool per_channel_quant = false);

template void PrepareWeight<float, float16>(
Graph* graph,
Scope* scope,
BlockDesc* block,
Node* weight,
Node** dst_weight,
Node** dst_weight_max,
Node** dst_scale_max,
bool transpose,
const std::vector<float>& weight_scales,
bool per_channel_quant = false);

template void PrepareWeight<float, int16_t>(
Graph* graph,
Scope* scope,
12 changes: 10 additions & 2 deletions paddle/fluid/framework/ir/xpu/pass_utils.h
@@ -57,6 +57,10 @@ std::vector<Node*> FindOpNodeByInputName(Graph* graph,
template <typename T>
size_t HashTensor(const phi::DenseTensor& in);

void ConvertFromFp32ToFp16(phi::DenseTensor* weight,
phi::DenseTensor* weight_max,
bool transpose);

template <typename Tcpu,
typename Txpu,
typename std::enable_if<!std::is_same<Tcpu, Txpu>::value, Tcpu>::type*
@@ -67,8 +71,12 @@ void ConvertWeightWrapper(phi::DenseTensor* weight,
bool transpose,
const std::vector<float>& weight_scales,
bool per_channel_quant) {
ConvertWithQuant<Tcpu, Txpu>(
weight, weight_max, scale_max, transpose, per_channel_quant);
if (std::is_same<Tcpu, float>::value && std::is_same<Txpu, float16>::value) {
ConvertFromFp32ToFp16(weight, weight_max, transpose);
} else {
ConvertWithQuant<Tcpu, Txpu>(
weight, weight_max, scale_max, transpose, per_channel_quant);
}
}

template <typename Tcpu,
1 change: 0 additions & 1 deletion paddle/fluid/framework/ir/xpu/quant_dequant_xpu_pass.cc
@@ -191,7 +191,6 @@ void QuantDequantXPUPass::CollectInputScalesFromQuantize(
if (out->Name() == out_var_name) {
for (auto* var : out->outputs) {
auto op_desc = var->Op();
std::string quantized_op_type = op_desc->Type();
op_desc->SetAttr("enable_int8", true);
op_desc->Flush();
}
60 changes: 33 additions & 27 deletions paddle/fluid/framework/ir/xpu/quant_utils.cc
@@ -115,41 +115,47 @@ void CastToInt32(phi::DenseTensor* in, phi::DenseTensor* out) {
Assign(*out_ptr, in);
}
}

void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out) {
void CastTo(phi::DenseTensor* in, phi::DenseTensor* out, DataType out_dtype) {
auto* cpu_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));

paddle::experimental::CheckAndTrans2Contiguous(in);
if (in->dtype() != phi::DataType::FLOAT16 &&
in->dtype() != phi::DataType::FLOAT32) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support fp16 and fp32, but received dtype is %s.",
phi::DataTypeToString(in->dtype())));
}

phi::DenseTensor fp32_tensor;
phi::DenseTensor* out_ptr = out == nullptr ? &fp32_tensor : out;
paddle::experimental::CheckAndTrans2Contiguous(in);
phi::DenseTensor ori_tensor;
phi::DenseTensor* out_ptr = out == nullptr ? &ori_tensor : out;
out_ptr->Resize(in->dims());
out_ptr->set_type(phi::DataType::FLOAT32);
out_ptr->set_type(out_dtype);
out_ptr->set_layout(in->layout());

switch (in->dtype()) {
case phi::DataType::FLOAT16:
phi::CastKernel<phi::dtype::float16>(
*cpu_ctx, *in, phi::DataType::FLOAT32, out_ptr);
break;
case phi::DataType::FLOAT32:
if (out == nullptr) {
return;
} else {
phi::AssignKernel(*cpu_ctx, *in, out_ptr);
}
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support fp16 and fp32, but received dtype is %s.",
phi::DataTypeToString(in->dtype())));
break;
if (in->dtype() == out_dtype) {
if (out == nullptr) {
return;
} else {
phi::AssignKernel(*cpu_ctx, *in, out_ptr);
}
} else {
if (in->dtype() == phi::DataType::FLOAT16) {
phi::CastKernel<float16>(*cpu_ctx, *in, out_dtype, out_ptr);
} else {
phi::CastKernel<float>(*cpu_ctx, *in, out_dtype, out_ptr);
}
if (out == nullptr) {
Assign(*out_ptr, in);
}
}
}

if (out == nullptr) {
Assign(*out_ptr, in);
}
void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out) {
CastTo(in, out, phi::DataType::FLOAT32);
}

void CastToFp16(phi::DenseTensor* in, phi::DenseTensor* out) {
CastTo(in, out, phi::DataType::FLOAT16);
}

static float FindMaxAbs(const float* data, int len) {
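CastTo and the thin CastToFp32 / CastToFp16 wrappers above treat a null out pointer as an in-place conversion. A minimal usage sketch, assuming a Paddle build and the include path and namespace used by these passes:

// Minimal usage sketch (assumes a Paddle build); CastToFp32/CastToFp16 are
// declared in quant_utils.h in namespace paddle::framework::ir.
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/phi/core/dense_tensor.h"

void CastUsageSketch(phi::DenseTensor* weight /* fp16 or fp32, on CPU */) {
  // Out-of-place: 'weight' is untouched, 'fp32_copy' receives the fp32 data.
  phi::DenseTensor fp32_copy;
  paddle::framework::ir::CastToFp32(weight, &fp32_copy);

  // In-place: with out == nullptr the result is assigned back into 'weight'.
  paddle::framework::ir::CastToFp16(weight);
}
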
4 changes: 4 additions & 0 deletions paddle/fluid/framework/ir/xpu/quant_utils.h
@@ -23,8 +23,12 @@ void Assign(const phi::DenseTensor& in, phi::DenseTensor* out);

void Transpose2D(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);

void CastTo(phi::DenseTensor* in, phi::DenseTensor* out, DataType dtype);

void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);

void CastToFp16(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);

void CastToInt32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);

template <typename T>
2 changes: 1 addition & 1 deletion paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -521,7 +521,7 @@ void CpuPassStrategy::EraseFcMkldnnPasses() {

XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
passes_.assign({
"quant_dequant_xpu_pass",
// "quant_dequant_xpu_pass", open this pass when use old int8 model
"delete_quant_dequant_linear_op_pass",
"delete_weight_dequant_linear_op_pass",
"delete_assign_op_pass",
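Per the comment above, quant_dequant_xpu_pass is now disabled by default. A hedged sketch of how an application could turn it back on for an old int8 model, using the public AnalysisConfig/PaddlePassBuilder API (InsertPass is an existing pass-builder method; the insertion index chosen here is illustrative):

// Hedged sketch: re-enable quant_dequant_xpu_pass for an old int8 model.
// Uses the public inference API; the insertion index is illustrative.
#include "paddle/fluid/inference/api/paddle_inference_api.h"

void EnableLegacyInt8XpuPass(paddle::AnalysisConfig* config) {
  config->EnableXpu();
  // Put the pass back at the front of the XPU pass list, where it sat
  // before this change.
  config->pass_builder()->InsertPass(0, "quant_dequant_xpu_pass");
}
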
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/fused_ops.yaml
@@ -399,7 +399,7 @@
backward : max_pool2d_v2_grad

- op : multi_encoder_xpu
args : (Tensor x, Tensor[] fc_input_max, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor mask, Tensor seq_lod, Tensor max_seq_len, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx, bool is_per_channel)
args : (Tensor x, Tensor[] fc_input_max, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] smooth_scale_weight, Tensor mask, Tensor seq_lod, Tensor max_seq_len, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx, bool is_per_channel, float[] softmax_max_value, str[] quant_types)
output : Tensor(out), Tensor(x_fp16), Tensor(out_fp16)
infer_meta :
func : MultiEncoderXPUInferMeta
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/fusion.cc
@@ -1446,6 +1446,7 @@ void MultiEncoderXPUInferMeta(
const std::vector<const MetaTensor*>& fc_bias,
const std::vector<const MetaTensor*>& ln_scale,
const std::vector<const MetaTensor*>& ln_bias,
const std::vector<const MetaTensor*>& smooth_scale_weight,
const MetaTensor& mask,
const MetaTensor& seq_lod,
const MetaTensor& max_seq_len,
@@ -1459,6 +1460,8 @@
int relative_type,
int slice_idx,
bool is_per_channel,
const std::vector<float>& softmax_max_value,
const std::vector<std::string>& quant_types,
MetaTensor* out,
MetaTensor* x_fp16,
MetaTensor* out_fp16) {
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/fusion.h
@@ -150,6 +150,7 @@ void MultiEncoderXPUInferMeta(
const std::vector<const MetaTensor*>& fc_bias,
const std::vector<const MetaTensor*>& ln_scale,
const std::vector<const MetaTensor*>& ln_bias,
const std::vector<const MetaTensor*>& smooth_scale_weight,
const MetaTensor& mask,
const MetaTensor& seq_lod,
const MetaTensor& max_seq_len,
@@ -163,6 +164,8 @@
int relative_type,
int slice_idx,
bool is_per_channel,
const std::vector<float>& softmax_max_value,
const std::vector<std::string>& quant_types,
MetaTensor* out,
MetaTensor* x_fp16,
MetaTensor* out_fp16);