
Commit d8cb4ef

[xpu] multi_encoder_xpu support smooth quant, skip quant and local quant
1 parent 4f1bffe · commit d8cb4ef

11 files changed: 508 additions, 167 deletions

paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc

Lines changed: 267 additions & 82 deletions
Large diffs are not rendered by default.

paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h

Lines changed: 8 additions & 3 deletions
@@ -128,6 +128,7 @@ struct PatternParam {
   bool norm_before;
   bool with_q_scale;
   bool with_mask;
+  bool is_smooth_quant;
 };
 
 class MultiEncoderXPUFusePass : public FusePassBase {
@@ -142,7 +143,8 @@ class MultiEncoderXPUFusePass : public FusePassBase {
       const std::string& matmul_type_2,
      bool norm_before,
      bool with_q_scale,
-      bool with_mask) const;
+      bool with_mask,
+      bool is_smooth_qunat) const;
 
   bool ApplyMultiEncoderXPUFuse(ir::Graph* graph) const;
 
@@ -152,7 +154,7 @@ class MultiEncoderXPUFusePass : public FusePassBase {
   // 1. Transpose q_w, k_w, v_w
   // 2. Concat q_w, k_w, v_w
   // 3. Generate qkv_w_max tensor
-  // 4. Quant qkv_w to int16
+  // 4. Quant qkv_w to int16/int8 or cast to float16 (local quant)
   void PrepareQKVWeight(
       Graph* graph,
       Scope* scope,
@@ -161,6 +163,7 @@ class MultiEncoderXPUFusePass : public FusePassBase {
       Node* k_w,
       Node* v_w,
       bool enable_int8,
+      bool local_quant,
       std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
       Node** qkv_w,
       Node** qkv_w_max,
@@ -171,7 +174,9 @@ class MultiEncoderXPUFusePass : public FusePassBase {
       BlockDesc* block,
       std::unordered_map<std::string, std::vector<Node*>>* node_maps,
       std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
-      std::vector<Node*>* input_max_nodes) const;
+      std::vector<Node*>* input_max_nodes,
+      std::vector<std::string>* quant_types,
+      const std::string* act_type) const;
 
   // 1. Cast bias to fp32
   // 2. Concat q/k/v bias
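
Taken together, the new flags decide how the fused QKV weight is stored. A minimal standalone sketch of that decision, assuming the semantics implied by the updated comment on PrepareQKVWeight (the helper below is illustrative only, not part of the pass):

// Illustrative helper (assumption): map the quant flags to a target weight type.
enum class WeightPrecision { kInt8, kInt16, kFp16 };

inline WeightPrecision SelectQkvWeightPrecision(bool enable_int8,
                                                bool local_quant) {
  if (enable_int8) return WeightPrecision::kInt8;   // smooth/skip-quant path
  if (local_quant) return WeightPrecision::kFp16;   // local quant: cast only, no quantization
  return WeightPrecision::kInt16;                   // default XPU weight type
}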

paddle/fluid/framework/ir/xpu/pass_utils.cc

Lines changed: 69 additions & 0 deletions
@@ -123,11 +123,68 @@ template size_t HashTensor<int16_t>(const phi::DenseTensor& in);
 template size_t HashTensor<float>(const phi::DenseTensor& in);
 template size_t HashTensor<int8_t>(const phi::DenseTensor& in);
 
+template <>
+size_t HashTensor<float16>(const phi::DenseTensor& in) {
+  phi::DenseTensor dst_tensor;
+  auto* cpu_ctx = static_cast<phi::CPUContext*>(
+      platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
+  dst_tensor.Resize(in.dims());
+  dst_tensor.set_type(phi::DataType::FLOAT32);
+  dst_tensor.set_layout(in.layout());
+  phi::CastKernel<float16>(*cpu_ctx, in, phi::DataType::FLOAT32, &dst_tensor);
+  return HashTensor<float>(dst_tensor);
+}
+
 std::string GetPrefixWithoutHash(const std::string& name) {
   std::size_t found = name.find("_#");
   return found == std::string::npos ? name : name.substr(0, found);
 }
 
+void ConvertFromFp32ToFp16(phi::DenseTensor* weight,
+                           phi::DenseTensor* weight_max,
+                           bool transpose) {
+  // Cast the (possibly fp16) weight to fp32 first
+  phi::DenseTensor weight_fp32;
+  CastToFp32(weight, &weight_fp32);
+
+  if (transpose) {  // (k, n) -> (n, k)
+    Transpose2D(&weight_fp32);
+  }
+
+  auto FindMaxAbs = [](const float* data, int len) {
+    float max_f = 0.0f;
+    for (int i = 0; i < len; ++i) {
+      float max = std::abs(data[i]);
+      if (max > max_f) {
+        max_f = max;
+      }
+    }
+    return max_f;
+  };
+
+  auto* cpu_ctx = static_cast<phi::CPUContext*>(
+      platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
+  // Convert to fp16
+  phi::DenseTensor weight_fp16;
+  CastToFp16(&weight_fp32, &weight_fp16);
+  // Find max
+  int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1);
+  int size = weight_fp32.numel();
+  float max_val = FindMaxAbs(weight_fp32.data<float>(), size);
+  std::vector<float> max_vec(max_ptr_size, max_val);
+  weight_max->set_type(phi::DataType::FLOAT32);
+  weight_max->Resize({max_ptr_size});
+  memcpy(cpu_ctx->Alloc<float>(weight_max),
+         max_vec.data(),
+         max_ptr_size * sizeof(float));
+  weight->clear();
+  weight->set_type(phi::DataType::FLOAT16);
+  weight->Resize({size});
+  memcpy(cpu_ctx->Alloc<float16>(weight),
+         weight_fp16.data<float16>(),
+         size * sizeof(float16));
+}
+
 template <typename Tcpu, typename Txpu>
 void PrepareWeight(Graph* graph,
                    Scope* scope,
@@ -268,6 +325,18 @@ template void PrepareWeight<float, float>(
     const std::vector<float>& weight_scales,
     bool per_channel_quant = false);
 
+template void PrepareWeight<float, float16>(
+    Graph* graph,
+    Scope* scope,
+    BlockDesc* block,
+    Node* weight,
+    Node** dst_weight,
+    Node** dst_weight_max,
+    Node** dst_scale_max,
+    bool transpose,
+    const std::vector<float>& weight_scales,
+    bool per_channel_quant = false);
+
 template void PrepareWeight<float, int16_t>(
     Graph* graph,
     Scope* scope,
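
For reference, the abs-max step performed by ConvertFromFp32ToFp16 above can be read as the standalone sketch below: the weight's per-tensor maximum absolute value is replicated max_ptr_size times so the XPU runtime can consume it (in the real code max_ptr_size comes from phi::backends::xpu::get_xpu_max_ptr_size; here it is just a parameter, and the function name is illustrative only):

#include <cmath>
#include <vector>

// Illustrative sketch: build the weight_max buffer for the fp16 (local quant) path.
std::vector<float> BuildWeightMax(const float* data, int len, int max_ptr_size) {
  float max_abs = 0.0f;
  for (int i = 0; i < len; ++i) {
    float v = std::abs(data[i]);
    if (v > max_abs) max_abs = v;
  }
  return std::vector<float>(max_ptr_size, max_abs);
}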

paddle/fluid/framework/ir/xpu/pass_utils.h

Lines changed: 10 additions & 2 deletions
@@ -57,6 +57,10 @@ std::vector<Node*> FindOpNodeByInputName(Graph* graph,
 template <typename T>
 size_t HashTensor(const phi::DenseTensor& in);
 
+void ConvertFromFp32ToFp16(phi::DenseTensor* weight,
+                           phi::DenseTensor* weight_max,
+                           bool transpose);
+
 template <typename Tcpu,
           typename Txpu,
           typename std::enable_if<!std::is_same<Tcpu, Txpu>::value, Tcpu>::type*
@@ -67,8 +71,12 @@ void ConvertWeightWrapper(phi::DenseTensor* weight,
                           bool transpose,
                           const std::vector<float>& weight_scales,
                           bool per_channel_quant) {
-  ConvertWithQuant<Tcpu, Txpu>(
-      weight, weight_max, scale_max, transpose, per_channel_quant);
+  if (std::is_same<Tcpu, float>::value && std::is_same<Txpu, float16>::value) {
+    ConvertFromFp32ToFp16(weight, weight_max, transpose);
+  } else {
+    ConvertWithQuant<Tcpu, Txpu>(
+        weight, weight_max, scale_max, transpose, per_channel_quant);
+  }
 }
 
 template <typename Tcpu,
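
The dispatch added to ConvertWeightWrapper is an ordinary runtime if on std::is_same, so both branches must still compile for every instantiated <Tcpu, Txpu> pair; only the <float, float16> pair actually takes the cast-only path. A self-contained illustration of the pattern, using a stand-in type instead of phi::dtype::float16:

#include <cstdint>
#include <cstdio>
#include <type_traits>

struct float16_t {};  // stand-in for phi::dtype::float16

template <typename Tcpu, typename Txpu>
void ConvertWeight() {
  if (std::is_same<Tcpu, float>::value && std::is_same<Txpu, float16_t>::value) {
    std::puts("local quant: cast fp32 -> fp16, keep per-tensor max");
  } else {
    std::puts("quantize with scales");
  }
}

int main() {
  ConvertWeight<float, float16_t>();  // fp16 cast path
  ConvertWeight<float, int16_t>();    // int16 quant path
  return 0;
}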

paddle/fluid/framework/ir/xpu/quant_dequant_xpu_pass.cc

Lines changed: 0 additions & 1 deletion
@@ -191,7 +191,6 @@ void QuantDequantXPUPass::CollectInputScalesFromQuantize(
     if (out->Name() == out_var_name) {
       for (auto* var : out->outputs) {
         auto op_desc = var->Op();
-        std::string quantized_op_type = op_desc->Type();
         op_desc->SetAttr("enable_int8", true);
         op_desc->Flush();
       }

paddle/fluid/framework/ir/xpu/quant_utils.cc

Lines changed: 33 additions & 27 deletions
@@ -115,41 +115,47 @@ void CastToInt32(phi::DenseTensor* in, phi::DenseTensor* out) {
     Assign(*out_ptr, in);
   }
 }
-
-void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out) {
+void CastTo(phi::DenseTensor* in, phi::DenseTensor* out, DataType out_dtype) {
   auto* cpu_ctx = static_cast<phi::CPUContext*>(
       platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
 
-  paddle::experimental::CheckAndTrans2Contiguous(in);
+  if (in->dtype() != phi::DataType::FLOAT16 &&
+      in->dtype() != phi::DataType::FLOAT32) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Only support fp16 and fp32, but received dtype is %s.",
+        phi::DataTypeToString(in->dtype())));
+  }
 
-  phi::DenseTensor fp32_tensor;
-  phi::DenseTensor* out_ptr = out == nullptr ? &fp32_tensor : out;
+  paddle::experimental::CheckAndTrans2Contiguous(in);
+  phi::DenseTensor ori_tensor;
+  phi::DenseTensor* out_ptr = out == nullptr ? &ori_tensor : out;
   out_ptr->Resize(in->dims());
-  out_ptr->set_type(phi::DataType::FLOAT32);
+  out_ptr->set_type(out_dtype);
   out_ptr->set_layout(in->layout());
-
-  switch (in->dtype()) {
-    case phi::DataType::FLOAT16:
-      phi::CastKernel<phi::dtype::float16>(
-          *cpu_ctx, *in, phi::DataType::FLOAT32, out_ptr);
-      break;
-    case phi::DataType::FLOAT32:
-      if (out == nullptr) {
-        return;
-      } else {
-        phi::AssignKernel(*cpu_ctx, *in, out_ptr);
-      }
-      break;
-    default:
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "Only support fp16 and fp32, but received dtype is %s.",
-          phi::DataTypeToString(in->dtype())));
-      break;
+  if (in->dtype() == out_dtype) {
+    if (out == nullptr) {
+      return;
+    } else {
+      phi::AssignKernel(*cpu_ctx, *in, out_ptr);
+    }
+  } else {
+    if (in->dtype() == phi::DataType::FLOAT16) {
+      phi::CastKernel<float16>(*cpu_ctx, *in, out_dtype, out_ptr);
+    } else {
+      phi::CastKernel<float>(*cpu_ctx, *in, out_dtype, out_ptr);
+    }
+    if (out == nullptr) {
+      Assign(*out_ptr, in);
+    }
   }
+}
 
-  if (out == nullptr) {
-    Assign(*out_ptr, in);
-  }
+void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out) {
+  CastTo(in, out, phi::DataType::FLOAT32);
+}
+
+void CastToFp16(phi::DenseTensor* in, phi::DenseTensor* out) {
+  CastTo(in, out, phi::DataType::FLOAT16);
 }
 
 static float FindMaxAbs(const float* data, int len) {
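
With this refactor, CastToFp32 and the new CastToFp16 are thin wrappers over CastTo, which rejects non-float inputs, assigns when the source already has the target dtype, and otherwise dispatches phi::CastKernel on the source dtype. A standalone sketch of that control flow (simplified types; the real code operates on phi::DenseTensor):

#include <stdexcept>
#include <string>

enum class DType { kFp16, kFp32, kInt32 };

// Illustrative only: mirrors the branch order of CastTo above.
std::string CastPlan(DType src, DType dst) {
  if (src != DType::kFp16 && src != DType::kFp32) {
    throw std::invalid_argument("Only fp16 and fp32 sources are supported.");
  }
  if (src == dst) return "assign (no cast needed)";
  return src == DType::kFp16 ? "CastKernel<float16> to target dtype"
                             : "CastKernel<float> to target dtype";
}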

paddle/fluid/framework/ir/xpu/quant_utils.h

Lines changed: 4 additions & 0 deletions
@@ -23,8 +23,12 @@ void Assign(const phi::DenseTensor& in, phi::DenseTensor* out);
 
 void Transpose2D(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
 
+void CastTo(phi::DenseTensor* in, phi::DenseTensor* out, DataType dtype);
+
 void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
 
+void CastToFp16(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
+
 void CastToInt32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
 
 template <typename T>

paddle/phi/api/yaml/fused_ops.yaml

Lines changed: 1 addition & 1 deletion
@@ -399,7 +399,7 @@
   backward : max_pool2d_v2_grad
 
 - op : multi_encoder_xpu
-  args : (Tensor x, Tensor[] fc_input_max, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor mask, Tensor seq_lod, Tensor max_seq_len, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx, bool is_per_channel)
+  args : (Tensor x, Tensor[] fc_input_max, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] smooth_scale_weight, Tensor mask, Tensor seq_lod, Tensor max_seq_len, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx, bool is_per_channel, float[] softmax_max_value, str[] quant_types)
   output : Tensor(out), Tensor(x_fp16), Tensor(out_fp16)
   infer_meta :
     func : MultiEncoderXPUInferMeta

paddle/phi/infermeta/fusion.cc

Lines changed: 3 additions & 0 deletions
@@ -1446,6 +1446,7 @@ void MultiEncoderXPUInferMeta(
     const std::vector<const MetaTensor*>& fc_bias,
     const std::vector<const MetaTensor*>& ln_scale,
     const std::vector<const MetaTensor*>& ln_bias,
+    const std::vector<const MetaTensor*>& smooth_scale_weight,
     const MetaTensor& mask,
     const MetaTensor& seq_lod,
     const MetaTensor& max_seq_len,
@@ -1459,6 +1460,8 @@ void MultiEncoderXPUInferMeta(
     int relative_type,
     int slice_idx,
     bool is_per_channel,
+    const std::vector<float>& softmax_max_value,
+    const std::vector<std::string>& quant_types,
     MetaTensor* out,
     MetaTensor* x_fp16,
     MetaTensor* out_fp16) {

paddle/phi/infermeta/fusion.h

Lines changed: 3 additions & 0 deletions
@@ -150,6 +150,7 @@ void MultiEncoderXPUInferMeta(
     const std::vector<const MetaTensor*>& fc_bias,
     const std::vector<const MetaTensor*>& ln_scale,
     const std::vector<const MetaTensor*>& ln_bias,
+    const std::vector<const MetaTensor*>& smooth_scale_weight,
     const MetaTensor& mask,
     const MetaTensor& seq_lod,
     const MetaTensor& max_seq_len,
@@ -163,6 +164,8 @@ void MultiEncoderXPUInferMeta(
     int relative_type,
     int slice_idx,
     bool is_per_channel,
+    const std::vector<float>& softmax_max_value,
+    const std::vector<std::string>& quant_types,
     MetaTensor* out,
     MetaTensor* x_fp16,
     MetaTensor* out_fp16);
