Commit 121733e

Wanglongzhi2001 authored and RichardWooSJTU committed
support cachekv_quant in blha
1 parent 93fda0a commit 121733e

File tree

6 files changed: +1178 -52 lines changed

paddle/phi/api/yaml/fused_ops.yaml

Lines changed: 2 additions & 2 deletions
@@ -33,14 +33,14 @@
     data_type : x
 
 - op : block_multihead_attention_
-  args : (Tensor qkv, Tensor key_cache, Tensor value_cache, Tensor seq_lens_encoder, Tensor seq_lens_decoder, Tensor seq_lens_this_time, Tensor padding_offsets, Tensor cum_offsets, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor block_tables, Tensor pre_key_cache, Tensor pre_value_cache, Tensor rope_emb, Tensor mask, Tensor tgt_mask, int max_seq_len, int block_size, bool use_neox_style)
+  args : (Tensor qkv, Tensor key_cache, Tensor value_cache, Tensor seq_lens_encoder, Tensor seq_lens_decoder, Tensor seq_lens_this_time, Tensor padding_offsets, Tensor cum_offsets, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor block_tables, Tensor pre_key_cache, Tensor pre_value_cache, Tensor rope_emb, Tensor mask, Tensor tgt_mask, Tensor cache_k_quant_scales, Tensor cache_v_quant_scales, Tensor cache_k_dequant_scales, Tensor cache_v_dequant_scales, Tensor qkv_out_scale, Tensor qkv_bias, Tensor out_shift, Tensor out_smooth, int max_seq_len, int block_size, bool use_neox_style, bool dynamic_cachekv_quant=false, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0, float out_scale=-1, str compute_dtype = "default")
   output : Tensor(fmha_out), Tensor(qkv_out), Tensor(key_cache_out), Tensor(value_cache_out)
   infer_meta :
     func : BlockMultiheadAttentionInferMeta
   kernel :
     func : block_multihead_attention
     data_type : qkv
-  optional : pre_key_cache, pre_value_cache, rope_emb, mask, tgt_mask
+  optional : pre_key_cache, pre_value_cache, rope_emb, mask, tgt_mask, cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales, cache_v_dequant_scales, qkv_out_scale, qkv_bias, out_shift, out_smooth
   inplace : (qkv -> qkv_out), (key_cache -> key_cache_out), (value_cache -> value_cache_out)
   support_dygraph_mode : true

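The op signature now carries cache-KV quant/dequant scale tensors (cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales, cache_v_dequant_scales) plus the attributes dynamic_cachekv_quant, quant_round_type, quant_max_bound and quant_min_bound. As a reading aid only, the sketch below shows how symmetric int8 cache-KV quantization is conventionally expressed with such scales and bounds; the formulas, the meaning of quant_round_type, and the dynamic-vs-static reading of dynamic_cachekv_quant are assumptions based on common practice, not the CUDA kernel code from this commit.

// Illustrative sketch only (not part of the commit): conventional symmetric
// int8 quantization expressed in terms of the new inputs/attributes.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <utility>

// Assumed meaning of the scale pair: quant_scale maps float cache values into
// the int8 range, dequant_scale maps them back. With static quantization
// (dynamic_cachekv_quant = false) these would be precomputed and passed via
// cache_*_quant_scales / cache_*_dequant_scales; with dynamic quantization the
// kernel would derive them on the fly from the current cache contents.
std::pair<float, float> MakeCacheKVScales(const float* values,
                                          std::size_t n,
                                          float quant_max_bound /* 127.0f */) {
  float abs_max = 1e-6f;  // guard against an all-zero slice
  for (std::size_t i = 0; i < n; ++i) {
    abs_max = std::max(abs_max, std::fabs(values[i]));
  }
  return {quant_max_bound / abs_max, abs_max / quant_max_bound};
}

// quant_round_type is assumed to select the rounding mode (1: round half away
// from zero); the result is clipped to [quant_min_bound, quant_max_bound].
int8_t QuantizeCacheKV(float value,
                       float quant_scale,
                       int quant_round_type,
                       float quant_max_bound,
                       float quant_min_bound) {
  float scaled = value * quant_scale;
  float rounded = (quant_round_type == 1) ? std::round(scaled)
                                          : std::nearbyint(scaled);
  return static_cast<int8_t>(
      std::min(quant_max_bound, std::max(quant_min_bound, rounded)));
}

float DequantizeCacheKV(int8_t q, float dequant_scale) {
  // dequant_scale: one entry of cache_k_dequant_scales / cache_v_dequant_scales
  return static_cast<float>(q) * dequant_scale;
}
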
paddle/phi/infermeta/fusion.cc

Lines changed: 77 additions & 2 deletions
@@ -132,9 +132,23 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
     const MetaTensor& rope_emb,
     const MetaTensor& mask,
     const MetaTensor& tgt_mask,
+    const MetaTensor& cache_k_quant_scales,
+    const MetaTensor& cache_v_quant_scales,
+    const MetaTensor& cache_k_dequant_scales,
+    const MetaTensor& cache_v_dequant_scales,
+    const MetaTensor& qkv_out_scale,
+    const MetaTensor& qkv_bias,
+    const MetaTensor& out_shift,
+    const MetaTensor& out_smooth,
     int max_seq_len,
     int block_size,
     bool use_neox_style,
+    bool dynamic_cachekv_quant,
+    const int quant_round_type,
+    const float quant_max_bound,
+    const float quant_min_bound,
+    const float out_scale,
+    const std::string& compute_dtype,
     MetaTensor* fmha_out,
     MetaTensor* qkv_out,
     MetaTensor* key_cache_out,
@@ -159,13 +173,74 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
           "The input_dims[1] must be equal to 3 * num_head * dim_head"));
 
   fmha_out->set_dims({input_dims[0], num_head * dim_head});
-  fmha_out->set_dtype(qkv.dtype());
   qkv_out->set_dims(qkv.dims());
-  qkv_out->set_dtype(qkv.dtype());
   key_cache_out->set_dims(key_cache_dims);
   key_cache_out->set_dtype(key_cache.dtype());
   value_cache_out->set_dims(key_cache_dims);
   value_cache_out->set_dtype(value_cache.dtype());
+
+  auto FBADtypeCheck = [](const MetaTensor& check_tensor,
+                          const std::string& tensor_name,
+                          const std::string& compute_dtype) {
+    if (compute_dtype == "bf16") {
+      PADDLE_ENFORCE_EQ(
+          check_tensor.dtype(),
+          phi::DataType::BFLOAT16,
+          phi::errors::InvalidArgument(
+              "Input(%s) dtype must be the same with Attr(compute_dtype)",
+              tensor_name));
+    } else if (compute_dtype == "fp16") {
+      PADDLE_ENFORCE_EQ(
+          check_tensor.dtype(),
+          phi::DataType::FLOAT16,
+          phi::errors::InvalidArgument(
+              "Input(%s) dtype must be the same with Attr(compute_dtype)",
+              tensor_name));
+    } else if (compute_dtype == "fp32") {
+      PADDLE_ENFORCE_EQ(
+          check_tensor.dtype(),
+          phi::DataType::FLOAT32,
+          phi::errors::InvalidArgument(
+              "Input(%s) dtype must be the same with Attr(compute_dtype)",
+              tensor_name));
+    }
+  };
+
+  // In the case of quantization enabled, the dtype for computation is
+  // determined based on compute_dtype.
+  if (qkv.dtype() == phi::DataType::INT32) {
+    PADDLE_ENFORCE_NE(
+        compute_dtype,
+        "default",
+        phi::errors::InvalidArgument(
+            "If Input(x) dtype is INT32, Attr(compute_dtype) must be set."));
+    if (out_scale > 0) {
+      fmha_out->set_dtype(phi::DataType::INT8);
+    } else {
+      if (compute_dtype == "bf16") {
+        fmha_out->set_dtype(phi::DataType::BFLOAT16);
+      } else if (compute_dtype == "fp16") {
+        fmha_out->set_dtype(phi::DataType::FLOAT16);
+      } else if (compute_dtype == "fp32") {
+        fmha_out->set_dtype(phi::DataType::FLOAT32);
+      } else {
+        PADDLE_THROW(phi::errors::InvalidArgument(
+            "In the case of quantization enabled with Input(x) INT32, "
+            "Attr(compute_dtype) must be set in (bf16, fp16, fp32), "
+            "but get compute_dtype (%s)",
+            compute_dtype));
+      }
+    }
+  } else {
+    if (compute_dtype != "default") {
+      FBADtypeCheck(qkv, "qkv", compute_dtype);
+    }
+    if (out_scale > 0) {
+      fmha_out->set_dtype(phi::DataType::INT8);
+    } else {
+      fmha_out->set_dtype(qkv.dtype());
+    }
+  }
 }
 
 void Conv1dXPUInferMeta(const MetaTensor& x,

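The new InferMeta branch above decides fmha_out's dtype from the qkv dtype, compute_dtype and out_scale (the cache outputs keep following their input dtypes). Condensed into a standalone helper for readability, the rule reads as the sketch below; it mirrors the diff but is not code from the commit.

// Condensed restatement of the new fmha_out dtype rule (reading aid only).
#include <stdexcept>
#include <string>

enum class DType { INT8, INT32, FLOAT16, BFLOAT16, FLOAT32 };

DType InferFmhaOutDtype(DType qkv_dtype,
                        const std::string& compute_dtype,
                        float out_scale) {
  if (qkv_dtype == DType::INT32) {
    // Quantized qkv input: the compute dtype must be named explicitly.
    if (compute_dtype == "default") {
      throw std::invalid_argument(
          "If Input(x) dtype is INT32, Attr(compute_dtype) must be set.");
    }
    if (out_scale > 0) return DType::INT8;  // quantized output requested
    if (compute_dtype == "bf16") return DType::BFLOAT16;
    if (compute_dtype == "fp16") return DType::FLOAT16;
    if (compute_dtype == "fp32") return DType::FLOAT32;
    throw std::invalid_argument(
        "compute_dtype must be one of bf16, fp16, fp32.");
  }
  // Non-quantized input: compute_dtype, when set, must match qkv's dtype;
  // the output follows qkv unless out_scale requests an INT8 output.
  return out_scale > 0 ? DType::INT8 : qkv_dtype;
}

For example, an INT32 qkv with compute_dtype = "fp16" and out_scale <= 0 yields a FLOAT16 fmha_out, while out_scale > 0 always produces an INT8 output.
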
paddle/phi/infermeta/fusion.h

Lines changed: 14 additions & 0 deletions
@@ -54,9 +54,23 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
     const MetaTensor& rope_emb,
     const MetaTensor& mask,
     const MetaTensor& tgt_mask,
+    const MetaTensor& cache_k_quant_scales,
+    const MetaTensor& cache_v_quant_scales,
+    const MetaTensor& cache_k_dequant_scales,
+    const MetaTensor& cache_v_dequant_scales,
+    const MetaTensor& qkv_out_scale,
+    const MetaTensor& qkv_bias,
+    const MetaTensor& out_shift,
+    const MetaTensor& out_smooth,
     int max_seq_len,
     int block_size,
     bool use_neox_style,
+    bool dynamic_cachekv_quant,
+    const int quant_round_type,
+    const float quant_max_bound,
+    const float quant_min_bound,
+    const float out_scale,
+    const std::string& compute_dtype,
     MetaTensor* fmha_out,
     MetaTensor* qkv_out,
     MetaTensor* key_cache_out,
