@@ -132,9 +132,23 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
                                       const MetaTensor& rope_emb,
                                       const MetaTensor& mask,
                                       const MetaTensor& tgt_mask,
+                                      const MetaTensor& cache_k_quant_scales,
+                                      const MetaTensor& cache_v_quant_scales,
+                                      const MetaTensor& cache_k_dequant_scales,
+                                      const MetaTensor& cache_v_dequant_scales,
+                                      const MetaTensor& qkv_out_scale,
+                                      const MetaTensor& qkv_bias,
+                                      const MetaTensor& out_shift,
+                                      const MetaTensor& out_smooth,
                                       int max_seq_len,
                                       int block_size,
                                       bool use_neox_style,
+                                      bool dynamic_cachekv_quant,
+                                      const int quant_round_type,
+                                      const float quant_max_bound,
+                                      const float quant_min_bound,
+                                      const float out_scale,
+                                      const std::string& compute_dtype,
                                       MetaTensor* fmha_out,
                                       MetaTensor* qkv_out,
                                       MetaTensor* key_cache_out,
@@ -159,13 +173,74 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
           "The input_dims[1] must be equal to 3 * num_head * dim_head"));
 
   fmha_out->set_dims({input_dims[0], num_head * dim_head});
-  fmha_out->set_dtype(qkv.dtype());
   qkv_out->set_dims(qkv.dims());
-  qkv_out->set_dtype(qkv.dtype());
   key_cache_out->set_dims(key_cache_dims);
   key_cache_out->set_dtype(key_cache.dtype());
   value_cache_out->set_dims(key_cache_dims);
   value_cache_out->set_dtype(value_cache.dtype());
+
+  auto FBADtypeCheck = [](const MetaTensor& check_tensor,
+                          const std::string& tensor_name,
+                          const std::string& compute_dtype) {
+    if (compute_dtype == "bf16") {
+      PADDLE_ENFORCE_EQ(
+          check_tensor.dtype(),
+          phi::DataType::BFLOAT16,
+          phi::errors::InvalidArgument(
+              "Input(%s) dtype must be the same with Attr(compute_dtype)",
+              tensor_name));
+    } else if (compute_dtype == "fp16") {
+      PADDLE_ENFORCE_EQ(
+          check_tensor.dtype(),
+          phi::DataType::FLOAT16,
+          phi::errors::InvalidArgument(
+              "Input(%s) dtype must be the same with Attr(compute_dtype)",
+              tensor_name));
+    } else if (compute_dtype == "fp32") {
+      PADDLE_ENFORCE_EQ(
+          check_tensor.dtype(),
+          phi::DataType::FLOAT32,
+          phi::errors::InvalidArgument(
+              "Input(%s) dtype must be the same with Attr(compute_dtype)",
+              tensor_name));
+    }
+  };
+
+  // In the case of quantization enabled, the dtype for computation is
+  // determined based on compute_dtype.
+  if (qkv.dtype() == phi::DataType::INT32) {
+    PADDLE_ENFORCE_NE(
+        compute_dtype,
+        "default",
+        phi::errors::InvalidArgument(
+            "If Input(x) dtype is INT32, Attr(compute_dtype) must be set."));
+    if (out_scale > 0) {
+      fmha_out->set_dtype(phi::DataType::INT8);
+    } else {
+      if (compute_dtype == "bf16") {
+        fmha_out->set_dtype(phi::DataType::BFLOAT16);
+      } else if (compute_dtype == "fp16") {
+        fmha_out->set_dtype(phi::DataType::FLOAT16);
+      } else if (compute_dtype == "fp32") {
+        fmha_out->set_dtype(phi::DataType::FLOAT32);
+      } else {
+        PADDLE_THROW(phi::errors::InvalidArgument(
+            "In the case of quantization enabled with Input(x) INT32, "
+            "Attr(compute_dtype) must be set in (bf16, fp16, fp32), "
+            "but get compute_dtype (%s)",
+            compute_dtype));
+      }
+    }
+  } else {
+    if (compute_dtype != "default") {
+      FBADtypeCheck(qkv, "qkv", compute_dtype);
+    }
+    if (out_scale > 0) {
+      fmha_out->set_dtype(phi::DataType::INT8);
+    } else {
+      fmha_out->set_dtype(qkv.dtype());
+    }
+  }
 }
 
 void Conv1dXPUInferMeta(const MetaTensor& x,
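The output-dtype rule the second hunk adds can be condensed into a small decision function. The sketch below is a standalone illustration under stated assumptions, not Paddle code: the `DType` enum and `ResolveFmhaOutDtype` are hypothetical names, and the `FBADtypeCheck` consistency check on the non-quantized path is elided. It mirrors the branch structure: an INT32 (quantized) qkv input requires an explicit compute_dtype, and out_scale > 0 means the attention output is re-quantized to INT8.

// Minimal standalone sketch of the fmha_out dtype resolution above.
// DType and ResolveFmhaOutDtype are illustrative names, not Paddle API.
#include <cstdio>
#include <stdexcept>
#include <string>

enum class DType { INT8, INT32, FLOAT16, BFLOAT16, FLOAT32 };

DType ResolveFmhaOutDtype(DType qkv_dtype,
                          const std::string& compute_dtype,
                          float out_scale) {
  if (qkv_dtype == DType::INT32) {
    // Quantized input: the compute dtype cannot be inferred from qkv.
    if (compute_dtype == "default") {
      throw std::invalid_argument(
          "If Input(x) dtype is INT32, Attr(compute_dtype) must be set.");
    }
    if (out_scale > 0) return DType::INT8;  // re-quantize the output
    if (compute_dtype == "bf16") return DType::BFLOAT16;
    if (compute_dtype == "fp16") return DType::FLOAT16;
    if (compute_dtype == "fp32") return DType::FLOAT32;
    throw std::invalid_argument(
        "compute_dtype must be one of bf16, fp16, fp32.");
  }
  // Non-quantized path: output follows the input unless re-quantized.
  return out_scale > 0 ? DType::INT8 : qkv_dtype;
}

int main() {
  // INT32 input with out_scale enabled -> INT8 output.
  std::printf("%d\n", static_cast<int>(ResolveFmhaOutDtype(
                          DType::INT32, "fp16", /*out_scale=*/1.0f)));
  // FLOAT16 input, no re-quantization -> output keeps the input dtype.
  std::printf("%d\n", static_cast<int>(ResolveFmhaOutDtype(
                          DType::FLOAT16, "default", /*out_scale=*/0.0f)));
  return 0;
}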