diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc index e1900e01761b9..8e126df64ad41 100644 --- a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc @@ -37,7 +37,8 @@ struct SingleEncoderXPUPattern : public PatternBase { const std::string& matmul_type_2, bool norm_before, bool with_q_scale, - bool with_mask); + bool with_mask, + bool is_smooth_quant); // declare operator node's name // If norm_before, use ln_0 & ln_1. @@ -132,6 +133,13 @@ struct SingleEncoderXPUPattern : public PatternBase { PATTERN_DECL_NODE(ln_2_out); PATTERN_DECL_NODE(ln_2_mean); PATTERN_DECL_NODE(ln_2_variance); + // smooth quant + PATTERN_DECL_NODE(smooth_scale_1); + PATTERN_DECL_NODE(smooth_scale_2); + PATTERN_DECL_NODE(smooth_scale_1_weight); + PATTERN_DECL_NODE(smooth_scale_2_weight); + PATTERN_DECL_NODE(smooth_scale_1_out); + PATTERN_DECL_NODE(smooth_scale_2_out); private: std::string act_type_; @@ -141,6 +149,7 @@ struct SingleEncoderXPUPattern : public PatternBase { bool norm_before_{false}; bool with_q_scale_{false}; bool with_mask_{true}; + bool is_smooth_quant_{false}; }; SingleEncoderXPUPattern::SingleEncoderXPUPattern( @@ -152,7 +161,8 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern( const std::string& matmul_type_2, bool norm_before, bool with_q_scale, - bool with_mask) + bool with_mask, + bool is_smooth_quant) : PatternBase(pattern, name_scope, name_scope), act_type_(act_type), matmul_type_0_(matmul_type_0), @@ -160,7 +170,8 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern( matmul_type_2_(matmul_type_2), norm_before_(norm_before), with_q_scale_(with_q_scale), - with_mask_(with_mask) { + with_mask_(with_mask), + is_smooth_quant_(is_smooth_quant) { // layer_norm 0 PDNode* ln_0_x = pattern->NewNode(ln_0_x_repr()); PDNode* ln_0_bias = nullptr; @@ -169,6 +180,9 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern( PDNode* ln_0_out = nullptr; PDNode* ln_0_mean = nullptr; PDNode* ln_0_variance = nullptr; + PDNode* smooth_scale_1_weight = nullptr; + PDNode* smooth_scale_1 = nullptr; + PDNode* smooth_scale_1_out = nullptr; if (norm_before_) { ln_0_x->assert_is_op_input("layer_norm", "X")->assert_var_not_persistable(); ln_0_bias = pattern->NewNode(ln_0_bias_repr()) @@ -188,6 +202,19 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern( ->assert_is_op_output("layer_norm", "Variance") ->assert_var_not_persistable(); } + if (!norm_before_ && is_smooth_quant_) { + VLOG(3) << "build first smooth_quant_scale"; + ln_0_x->assert_is_op_input("elementwise_mul", "X") + ->assert_var_not_persistable(); + smooth_scale_1_weight = pattern->NewNode(smooth_scale_1_weight_repr()) + ->assert_is_op_input("elementwise_mul", "Y") + ->assert_is_persistable_var(); + smooth_scale_1 = pattern->NewNode(smooth_scale_1_repr()) + ->assert_is_op("elementwise_mul"); + smooth_scale_1_out = pattern->NewNode(smooth_scale_1_out_repr()) + ->assert_is_op_output("elementwise_mul", "Out") + ->assert_var_not_persistable(); + } // q: matmul + add + reshape + transpose auto q_matmul_w = pattern->NewNode(q_matmul_w_repr()) @@ -362,6 +389,22 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern( auto* ln_1_variance = pattern->NewNode(ln_1_variance_repr()) ->assert_is_op_output("layer_norm", "Variance") ->assert_var_not_persistable(); + PDNode* smooth_scale_2_weight = nullptr; + PDNode* smooth_scale_2 = nullptr; + PDNode* smooth_scale_2_out = nullptr; + if (!norm_before_ && is_smooth_quant_) { + 
VLOG(3) << "build second smooth_quant_scale"; + ln_1_out->assert_is_op_input("elementwise_mul", "X") + ->assert_var_not_persistable(); + smooth_scale_2_weight = pattern->NewNode(smooth_scale_2_weight_repr()) + ->assert_is_op_input("elementwise_mul", "Y") + ->assert_is_persistable_var(); + smooth_scale_2 = pattern->NewNode(smooth_scale_2_repr()) + ->assert_is_op("elementwise_mul"); + smooth_scale_2_out = pattern->NewNode(smooth_scale_2_out_repr()) + ->assert_is_op_output("elementwise_mul", "Out") + ->assert_var_not_persistable(); + } auto qkv_matmul_2_w = pattern->NewNode(qkv_matmul_2_w_repr()) ->assert_is_op_input(matmul_type_0_, "Y") ->assert_is_persistable_var(); @@ -471,7 +514,15 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern( .LinksTo({qkv_matmul_1_out}); qkv_add_0->LinksFrom({qkv_matmul_1_out, qkv_add_0_bias}) .LinksTo({qkv_add_0_out}); - qkv_add_1->LinksFrom({qkv_add_0_out, q_matmul_x}).LinksTo({qkv_add_1_out}); + if (!norm_before_ && is_smooth_quant_) { + smooth_scale_1->LinksFrom({q_matmul_x, smooth_scale_1_weight}) + .LinksTo({smooth_scale_1_out}); + qkv_add_1->LinksFrom({qkv_add_0_out, smooth_scale_1_out}) + .LinksTo({qkv_add_1_out}); + } else { + qkv_add_1->LinksFrom({qkv_add_0_out, q_matmul_x}).LinksTo({qkv_add_1_out}); + } + ln_1->LinksFrom({qkv_add_1_out, ln_1_bias, ln_1_scale}) .LinksTo({ln_1_out, ln_1_mean, ln_1_variance}); qkv_matmul_2->LinksFrom({ln_1_out, qkv_matmul_2_w}) @@ -487,7 +538,14 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern( qkv_add_4->LinksFrom({qkv_add_3_out, qkv_add_1_out}) .LinksTo({qkv_add_4_out}); } else { - qkv_add_4->LinksFrom({qkv_add_3_out, ln_1_out}).LinksTo({qkv_add_4_out}); + if (is_smooth_quant_) { + smooth_scale_2->LinksFrom({ln_1_out, smooth_scale_2_weight}) + .LinksTo({smooth_scale_2_out}); + qkv_add_4->LinksFrom({qkv_add_3_out, smooth_scale_2_out}) + .LinksTo({qkv_add_4_out}); + } else { + qkv_add_4->LinksFrom({qkv_add_3_out, ln_1_out}).LinksTo({qkv_add_4_out}); + } ln_2->LinksFrom({qkv_add_4_out, ln_2_bias, ln_2_scale}) .LinksTo({ln_2_out, ln_2_mean, ln_2_variance}); } @@ -512,7 +570,8 @@ void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const { pattern_param.matmul_type_2, pattern_param.norm_before, pattern_param.with_q_scale, - pattern_param.with_mask); + pattern_param.with_mask, + pattern_param.is_smooth_quant); while (ApplyMultiEncoderXPUFuse(graph)) { multi_encoder_fused_counts++; } @@ -530,7 +589,9 @@ void MultiEncoderXPUFusePass::PrepareInputMax( BlockDesc* block, std::unordered_map>* node_maps, std::unordered_map>* var_quant_scales, - std::vector* input_max_nodes) const { + std::vector* input_max_nodes, + std::vector* quant_types, + const std::string* act_type) const { // mul input_max, output_max * 6 + matmul x_max,y_max,output_max * 2 auto quant_mul_ops = node_maps->find("quant_mul_ops")->second; auto matmul_ops = node_maps->find("matmul_ops")->second; @@ -541,12 +602,18 @@ void MultiEncoderXPUFusePass::PrepareInputMax( for (size_t i = 0; i < quant_mul_ops.size(); ++i) { auto input_name = quant_mul_ops[i]->Op()->Input("X")[0]; auto output_name = mul_add_ops[i]->Op()->Output("Out")[0]; - if (var_quant_scales->find(input_name) != var_quant_scales->end() && + if (quant_mul_ops[i]->Op()->HasAttr("enable_int8") && + PADDLE_GET_CONST(bool, + quant_mul_ops[i]->Op()->GetAttr("enable_int8")) && + var_quant_scales->find(input_name) != var_quant_scales->end() && var_quant_scales->find(output_name) != var_quant_scales->end()) { input_max[i * 2] = var_quant_scales->at(input_name)[0]; input_max[i * 2 + 1] = 
var_quant_scales->at(output_name)[0]; VLOG(3) << quant_mul_ops[i] << " input_max: " << input_max[i * 2] << ", output_max(ew_add): " << input_max[i * 2 + 1]; + quant_types->push_back("enable_int8"); + } else { + quant_types->push_back("not_quantized"); } } float max_qkv_input = std::max(input_max[0], input_max[2]); @@ -562,6 +629,22 @@ void MultiEncoderXPUFusePass::PrepareInputMax( VLOG(3) << "max_qkv_input: " << max_qkv_input << ", max_qkv_output: " << max_qkv_output; + if (*act_type == "gelu") { + // use gelu10 according to whitepaper http://arxiv.org/abs/2004.09602 + float gelu_out_threshold = 10.f; + if (std::getenv("QUANT_GELU_OUT_THRESHOLD")) { + gelu_out_threshold = atof(std::getenv("QUANT_GELU_OUT_THRESHOLD")); + PADDLE_ENFORCE_GT( + gelu_out_threshold, + 0.f, + phi::errors::InvalidArgument( + "QUANT_GELU_OUT_THRESHOLD should be an positive float value: %f", + gelu_out_threshold)); + } + input_max[9] = std::min(gelu_out_threshold, input_max[9]); + input_max[10] = std::min(gelu_out_threshold, input_max[10]); + } + auto input_x_name = matmul_ops[0]->Op()->Input("X")[0]; auto input_y_name = matmul_ops[0]->Op()->Input("Y")[0]; auto output_name = softmax_ops[0]->Op()->Output("Out")[0]; @@ -583,6 +666,9 @@ void MultiEncoderXPUFusePass::PrepareInputMax( << " Y_max: " << input_max[matmul_offset * 2 + 1] << " Out_max: " << input_max[matmul_offset * 2 + 2]; matmul_quants[0] = true; + quant_types->push_back("enable_int8"); + } else { + quant_types->push_back("not_quantized"); } input_x_name = matmul_ops[1]->Op()->Input("X")[0]; input_y_name = matmul_ops[1]->Op()->Input("Y")[0]; @@ -599,8 +685,11 @@ void MultiEncoderXPUFusePass::PrepareInputMax( << " Y_max: " << input_max[matmul_offset * 2 + 4] << " Out_max: " << input_max[matmul_offset * 2 + 5]; matmul_quants[1] = true; + quant_types->push_back("enable_int8"); + } else { + quant_types->push_back("not_quantized"); } - if (matmul_quants[0] == false || matmul_quants[1] == false) { + if (matmul_quants[0] == false && matmul_quants[1] == false) { // For backward compatible, API uses the size of input_max vector to // check whether it is mul quant or mul+matmul quant. 
input_max.resize(quant_mul_ops.size() * 2); @@ -644,6 +733,7 @@ void MultiEncoderXPUFusePass::PrepareQKVWeight( Node* k_w, Node* v_w, bool enable_int8, + bool local_quant, std::unordered_map>* var_quant_scales, Node** qkv_w_intx, Node** qkv_w_max, @@ -676,9 +766,14 @@ void MultiEncoderXPUFusePass::PrepareQKVWeight( CastToFp32(&k_w_t); CastToFp32(&v_w_t); phi::ConcatKernel(*cpu_ctx, in_tensors, 0, &qkv_w_intx_t); - ConvertWithQuant( - &qkv_w_intx_t, &qkv_w_max_t, &qkv_scale_max_t, false); - qkv_w_intx_hash = HashTensor(qkv_w_intx_t); + if (local_quant) { + qkv_w_intx_hash = HashTensor(qkv_w_intx_t); + ConvertFromFp32ToFp16(&qkv_w_intx_t, &qkv_w_max_t, false); + } else { + ConvertWithQuant( + &qkv_w_intx_t, &qkv_w_max_t, &qkv_scale_max_t, false); + qkv_w_intx_hash = HashTensor(qkv_w_intx_t); + } } else { phi::ConcatKernel(*cpu_ctx, in_tensors, 0, &qkv_w_intx_t); std::unordered_map> var_quant_scales = @@ -854,7 +949,12 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( const std::string& matmul_type_2, bool norm_before, bool with_q_scale, - bool with_mask) const { + bool with_mask, + bool is_smooth_quant) const { + bool local_quant = false; + if (std::getenv("XPU_LOCAL_QUANT")) { + local_quant = atoi(std::getenv("XPU_LOCAL_QUANT")); + } GraphPatternDetector gpd; patterns::SingleEncoderXPUPattern pattern(gpd.mutable_pattern(), name_scope_, @@ -864,7 +964,8 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( matmul_type_2, norm_before, with_q_scale, - with_mask); + with_mask, + is_smooth_quant); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -959,6 +1060,13 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( GET_IR_NODE(ln_2_out); GET_IR_NODE(ln_2_mean); GET_IR_NODE(ln_2_variance); + // smooth quant + GET_IR_NODE(smooth_scale_1); + GET_IR_NODE(smooth_scale_2); + GET_IR_NODE(smooth_scale_1_weight); + GET_IR_NODE(smooth_scale_2_weight); + GET_IR_NODE(smooth_scale_1_out); + GET_IR_NODE(smooth_scale_2_out); auto* block = q_matmul->Op()->Block(); auto* scope = param_scope(); @@ -974,6 +1082,8 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( std::vector input_max; std::unordered_map> var_quant_scales = GetQuantInfoFromTheGraph(graph, "has_quant_info", "var_quant_scales"); + std::vector quant_types; + std::vector new_add_nodes; if (use_precision == "int8") { std::unordered_map> node_maps; std::vector quant_mul_ops = {q_matmul, @@ -992,10 +1102,15 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( node_maps.insert(std::make_pair("softmax_ops", softmax_ops)); input_max.resize((quant_mul_ops.size() * 2 + quant_mul_ops.size() * 3), nullptr); - PrepareInputMax( - graph, scope, block, &node_maps, &var_quant_scales, &input_max); + PrepareInputMax(graph, + scope, + block, + &node_maps, + &var_quant_scales, + &input_max, + &quant_types, + &act_type); } - Node* qkv_w_intx = nullptr; Node* qkv_w_max = nullptr; Node* qkv_scale_max = nullptr; @@ -1006,39 +1121,60 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( k_matmul_w, v_matmul_w, (use_precision == "int8"), + local_quant, &var_quant_scales, &qkv_w_intx, &qkv_w_max, &qkv_scale_max); + new_add_nodes.push_back(qkv_w_intx); + new_add_nodes.push_back(qkv_w_max); + new_add_nodes.push_back(qkv_scale_max); + +#define PREPARE_QKV_MATMUL_W(idx_) \ + Node* qkv_matmul_##idx_##_w_intx = nullptr; \ + Node* qkv_matmul_##idx_##_w_max = nullptr; \ + Node* qkv_matmul_##idx_##_scale_max = nullptr; \ + if (var_quant_scales.find(qkv_matmul_##idx_##_w->Name()) == \ + 
var_quant_scales.end()) { \ + if (local_quant) { \ + PrepareWeight(graph, \ + scope, \ + block, \ + qkv_matmul_##idx_##_w, \ + &qkv_matmul_##idx_##_w_intx, \ + &qkv_matmul_##idx_##_w_max, \ + &qkv_matmul_##idx_##_scale_max, \ + true, \ + std::vector({})); \ + } else { \ + PrepareWeight(graph, \ + scope, \ + block, \ + qkv_matmul_##idx_##_w, \ + &qkv_matmul_##idx_##_w_intx, \ + &qkv_matmul_##idx_##_w_max, \ + &qkv_matmul_##idx_##_scale_max, \ + true, \ + std::vector({})); \ + } \ + } else { \ + std::vector weight_scales = \ + var_quant_scales.at(qkv_matmul_##idx_##_w->Name()); \ + is_per_channel = (weight_scales.size() != 1); \ + PrepareWeight(graph, \ + scope, \ + block, \ + qkv_matmul_##idx_##_w, \ + &qkv_matmul_##idx_##_w_intx, \ + &qkv_matmul_##idx_##_w_max, \ + &qkv_matmul_##idx_##_scale_max, \ + true, \ + weight_scales); \ + } \ + new_add_nodes.push_back(qkv_matmul_##idx_##_w_intx); \ + new_add_nodes.push_back(qkv_matmul_##idx_##_w_max); \ + new_add_nodes.push_back(qkv_matmul_##idx_##_scale_max); -#define PREPARE_QKV_MATMUL_W(idx_) \ - Node* qkv_matmul_##idx_##_w_intx = nullptr; \ - Node* qkv_matmul_##idx_##_w_max = nullptr; \ - Node* qkv_matmul_##idx_##_scale_max = nullptr; \ - if (use_precision != "int8") { \ - PrepareWeight(graph, \ - scope, \ - block, \ - qkv_matmul_##idx_##_w, \ - &qkv_matmul_##idx_##_w_intx, \ - &qkv_matmul_##idx_##_w_max, \ - &qkv_matmul_##idx_##_scale_max, \ - true, \ - std::vector({})); \ - } else { \ - std::vector weight_scales = \ - var_quant_scales.at(qkv_matmul_##idx_##_w->Name()); \ - is_per_channel = (weight_scales.size() != 1); \ - PrepareWeight(graph, \ - scope, \ - block, \ - qkv_matmul_##idx_##_w, \ - &qkv_matmul_##idx_##_w_intx, \ - &qkv_matmul_##idx_##_w_max, \ - &qkv_matmul_##idx_##_scale_max, \ - true, \ - weight_scales); \ - } PREPARE_QKV_MATMUL_W(1); PREPARE_QKV_MATMUL_W(2); PREPARE_QKV_MATMUL_W(3); @@ -1071,24 +1207,33 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( } } op_desc.SetInput("fc_input_max", fc_input_max_names); - op_desc.SetInput("fc_weight", - {qkv_w_intx->Name(), - qkv_matmul_1_w_intx->Name(), - qkv_matmul_2_w_intx->Name(), - qkv_matmul_3_w_intx->Name()}); - if (!is_per_channel) { - op_desc.SetInput("fc_weight_max", - {qkv_w_max->Name(), - qkv_matmul_1_w_max->Name(), - qkv_matmul_2_w_max->Name(), - qkv_matmul_3_w_max->Name()}); - } else { - op_desc.SetInput("fc_weight_max", - {qkv_scale_max->Name(), - qkv_matmul_1_scale_max->Name(), - qkv_matmul_2_scale_max->Name(), - qkv_matmul_3_scale_max->Name()}); + std::vector fc_weight_nodes = { + qkv_matmul_1_w_intx, qkv_matmul_2_w_intx, qkv_matmul_3_w_intx}; + std::vector fc_weight_names; + fc_weight_names.push_back(qkv_w_intx->Name()); + for (size_t i = 0; i < fc_weight_nodes.size(); i++) { + if (fc_weight_nodes[i]) { + fc_weight_names.push_back(fc_weight_nodes[i]->Name()); + } + } + op_desc.SetInput("fc_weight", fc_weight_names); + std::vector fc_weight_max_names; + + std::vector fc_weight_max_nodes = { + qkv_w_max, qkv_matmul_1_w_max, qkv_matmul_2_w_max, qkv_matmul_3_w_max}; + std::vector fc_weight_sacle_nodes = {qkv_scale_max, + qkv_matmul_1_scale_max, + qkv_matmul_2_scale_max, + qkv_matmul_3_scale_max}; + + for (size_t i = 0; i < fc_weight_sacle_nodes.size(); i++) { + if (fc_weight_sacle_nodes[i]) { + fc_weight_max_names.push_back(fc_weight_sacle_nodes[i]->Name()); + } else { + fc_weight_max_names.push_back(fc_weight_max_nodes[i]->Name()); + } } + op_desc.SetInput("fc_weight_max", fc_weight_max_names); op_desc.SetInput("fc_bias", {qkv_add_bias_fp32->Name(), 
qkv_add_0_bias_fp32->Name(), @@ -1104,6 +1249,17 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( if (with_mask) { op_desc.SetInput("mask", {qk_add_mask->Name()}); } + std::vector smooth_scale_names; + if (is_smooth_quant) { + smooth_scale_names.push_back(smooth_scale_1_weight->Name()); + smooth_scale_names.push_back(smooth_scale_2_weight->Name()); + for (auto smooth_scale_name : smooth_scale_names) { + auto* in = + scope->FindVar(smooth_scale_name)->GetMutable(); + CastToFp16(in); + } + } + op_desc.SetInput("smooth_scale_weight", {smooth_scale_names}); op_desc.SetAttr("norm_before", norm_before); op_desc.SetAttr("hidden_dim", static_cast(q_matmul_w->Var()->GetShape()[0])); @@ -1119,6 +1275,13 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( op_desc.SetAttr("relative_type", static_cast(0)); op_desc.SetAttr("use_precision", use_precision); op_desc.SetAttr("is_per_channel", is_per_channel); + // if quant,skip softmax,and use qk_matmul out_threshold as softmax_max + auto softmax_max_name = qk_matmul->Op()->Output("Out")[0]; + if (var_quant_scales.find(softmax_max_name) != var_quant_scales.end()) { + op_desc.SetAttr("softmax_max_value", + var_quant_scales.at(softmax_max_name)[0]); + } + op_desc.SetAttr("quant_types", quant_types); if (norm_before) { op_desc.SetOutput("out", {qkv_add_4_out->Name()}); } else { @@ -1132,20 +1295,11 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( IR_NODE_LINK_TO(node, single_encoder_xpu); } } - if (is_per_channel) { - IR_NODE_LINK_TO(qkv_scale_max, single_encoder_xpu); - IR_NODE_LINK_TO(qkv_matmul_1_scale_max, single_encoder_xpu); - IR_NODE_LINK_TO(qkv_matmul_2_scale_max, single_encoder_xpu); - IR_NODE_LINK_TO(qkv_matmul_3_scale_max, single_encoder_xpu); + for (auto* node : new_add_nodes) { + if (node) { + IR_NODE_LINK_TO(node, single_encoder_xpu); + } } - IR_NODE_LINK_TO(qkv_w_intx, single_encoder_xpu); - IR_NODE_LINK_TO(qkv_w_max, single_encoder_xpu); - IR_NODE_LINK_TO(qkv_matmul_1_w_intx, single_encoder_xpu); - IR_NODE_LINK_TO(qkv_matmul_1_w_max, single_encoder_xpu); - IR_NODE_LINK_TO(qkv_matmul_2_w_intx, single_encoder_xpu); - IR_NODE_LINK_TO(qkv_matmul_2_w_max, single_encoder_xpu); - IR_NODE_LINK_TO(qkv_matmul_3_w_intx, single_encoder_xpu); - IR_NODE_LINK_TO(qkv_matmul_3_w_max, single_encoder_xpu); IR_NODE_LINK_TO(qkv_add_bias_fp32, single_encoder_xpu); IR_NODE_LINK_TO(qkv_add_0_bias_fp32, single_encoder_xpu); IR_NODE_LINK_TO(qkv_add_2_bias_fp32, single_encoder_xpu); @@ -1162,6 +1316,10 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( } else { IR_NODE_LINK_TO(single_encoder_xpu, ln_2_out); } + if (is_smooth_quant) { + IR_NODE_LINK_TO(smooth_scale_1_weight, single_encoder_xpu); + IR_NODE_LINK_TO(smooth_scale_2_weight, single_encoder_xpu); + } // Delete nodes std::unordered_set delete_nodes{ln_1, @@ -1241,6 +1399,12 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( delete_nodes.insert(qk_add); delete_nodes.insert(qk_add_out); } + if (is_smooth_quant) { + delete_nodes.insert(smooth_scale_1); + delete_nodes.insert(smooth_scale_2); + delete_nodes.insert(smooth_scale_1_out); + delete_nodes.insert(smooth_scale_2_out); + } GraphSafeRemoveNodes(graph, delete_nodes); found_subgraph_count++; }; @@ -1288,7 +1452,8 @@ bool MultiEncoderXPUFusePass::ApplyMultiEncoderXPUFuse(ir::Graph* graph) const { "fc_weight_max", "fc_bias", "ln_scale", - "ln_bias"}; + "ln_bias", + "smooth_scale_weight"}; std::map> arg_names_map; std::string mask_name = single_encoders[0]->Op()->Inputs().count("mask") > 0 ? 
single_encoders[0]->Op()->Inputs().at("mask")[0] @@ -1375,6 +1540,22 @@ bool MultiEncoderXPUFusePass::ApplyMultiEncoderXPUFuse(ir::Graph* graph) const { "is_per_channel", PADDLE_GET_CONST(bool, single_encoders[0]->Op()->GetAttr("is_per_channel"))); + std::vector softmax_max_values; + for (auto* single_encoder : single_encoders) { + if (single_encoder->Op()->HasAttr("softmax_max_value")) { + softmax_max_values.push_back(PADDLE_GET_CONST( + float, single_encoder->Op()->GetAttr("softmax_max_value"))); + } + } + op_desc.SetAttr("softmax_max_value", softmax_max_values); + std::vector quant_types; + for (auto* single_encoder : single_encoders) { + auto per_quant_types = PADDLE_GET_CONST( + std::vector, single_encoder->Op()->GetAttr("quant_types")); + quant_types.insert( + quant_types.end(), per_quant_types.begin(), per_quant_types.end()); + } + op_desc.SetAttr("quant_types", quant_types); op_desc.SetOutput("out", {out_name}); op_desc.SetOutput("x_fp16", {x_fp16_name}); op_desc.SetOutput("out_fp16", {out_fp16_name}); @@ -1411,8 +1592,7 @@ int MultiEncoderXPUFusePass::CastMask(ir::Graph* graph) const { auto use_precision = op_desc->GetAttrIfExists("use_precision"); if (node->IsVar() || // op_desc->Type() != "multi_encoder_xpu" || - (use_precision != "float16" && use_precision != "int8") || - op_desc->Inputs().count("mask") == 0) + (use_precision != "float16") || op_desc->Inputs().count("mask") == 0) continue; auto* block = op_desc->Block(); @@ -1462,10 +1642,15 @@ std::vector MultiEncoderXPUFusePass::GeneratePatternParams() const { return std::vector{ // Params are arranged in alphabetic order - {"gelu", "matmul_v2", "matmul", "matmul_v2", false, false, true}, - {"gelu", "matmul_v2", "matmul_v2", "matmul_v2", false, true, true}, - {"gelu", "mul", "matmul", "matmul", false, true, true}, - {"relu", "mul", "matmul", "matmul", false, true, true}, + {"gelu", "matmul_v2", "matmul", "matmul_v2", false, false, true, false}, + {"gelu", "matmul_v2", "matmul_v2", "matmul_v2", false, true, true, false}, + {"gelu", "mul", "matmul", "matmul", false, true, true, false}, + {"relu", "mul", "matmul", "matmul", false, true, true, false}, + + {"gelu", "matmul_v2", "matmul", "matmul_v2", false, false, true, true}, + {"gelu", "matmul_v2", "matmul_v2", "matmul_v2", false, true, true, true}, + {"gelu", "mul", "matmul", "matmul", false, true, true, true}, + {"relu", "mul", "matmul", "matmul", false, true, true, true}, }; } diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h index 41f3aaf44fdf5..cb24ea8128451 100644 --- a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h @@ -128,6 +128,7 @@ struct PatternParam { bool norm_before; bool with_q_scale; bool with_mask; + bool is_smooth_quant; }; class MultiEncoderXPUFusePass : public FusePassBase { @@ -142,7 +143,8 @@ class MultiEncoderXPUFusePass : public FusePassBase { const std::string& matmul_type_2, bool norm_before, bool with_q_scale, - bool with_mask) const; + bool with_mask, + bool is_smooth_qunat) const; bool ApplyMultiEncoderXPUFuse(ir::Graph* graph) const; @@ -152,7 +154,7 @@ class MultiEncoderXPUFusePass : public FusePassBase { // 1. Transpose q_w, k_w, v_w // 2. Concat q_w, k_w, v_w // 3. Generate qkv_w_max tensor - // 4. Quant qkv_w to int16 + // 4. 
Quant qkv_w to int16/int8 or cast to float16 (local quant) void PrepareQKVWeight( Graph* graph, Scope* scope, @@ -161,6 +163,7 @@ class MultiEncoderXPUFusePass : public FusePassBase { Node* k_w, Node* v_w, bool enable_int8, + bool local_quant, std::unordered_map>* var_quant_scales, Node** qkv_w, Node** qkv_w_max, @@ -171,7 +174,9 @@ class MultiEncoderXPUFusePass : public FusePassBase { BlockDesc* block, std::unordered_map>* node_maps, std::unordered_map>* var_quant_scales, - std::vector* input_max_nodes) const; + std::vector* input_max_nodes, + std::vector* quant_types, + const std::string* act_type) const; // 1. Cast bias to fp32 // 2. Concat q/k/v bias diff --git a/paddle/fluid/framework/ir/xpu/pass_utils.cc b/paddle/fluid/framework/ir/xpu/pass_utils.cc index f9844bc813457..b0853690c065a 100644 --- a/paddle/fluid/framework/ir/xpu/pass_utils.cc +++ b/paddle/fluid/framework/ir/xpu/pass_utils.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/ir/xpu/pass_utils.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/kernels/cast_kernel.h" #include "paddle/phi/kernels/xpu/xpu_api_wrapper.h" namespace paddle { @@ -123,11 +124,68 @@ template size_t HashTensor(const phi::DenseTensor& in); template size_t HashTensor(const phi::DenseTensor& in); template size_t HashTensor(const phi::DenseTensor& in); +template <> +size_t HashTensor(const phi::DenseTensor& in) { + phi::DenseTensor dst_tensor; + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + dst_tensor.Resize(in.dims()); + dst_tensor.set_type(phi::DataType::FLOAT32); + dst_tensor.set_layout(in.layout()); + phi::CastKernel(*cpu_ctx, in, phi::DataType::FLOAT32, &dst_tensor); + return HashTensor(dst_tensor); +} + std::string GetPrefixWithoutHash(const std::string& name) { std::size_t found = name.find("_#"); return found == std::string::npos ? 
name : name.substr(0, found); } +void ConvertFromFp32ToFp16(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose) { + // Convert fp16 to fp32 + phi::DenseTensor weight_fp32; + CastToFp32(weight, &weight_fp32); + + if (transpose) { // (k, n) -> (n, k) + Transpose2D(&weight_fp32); + } + + auto FindMaxAbs = [](const float* data, int len) { + float max_f = 0.0f; + for (int i = 0; i < len; ++i) { + float max = std::abs(data[i]); + if (max > max_f) { + max_f = max; + } + } + return max_f; + }; + + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + // Convert to fp16 + phi::DenseTensor weight_fp16; + CastToFp16(&weight_fp32, &weight_fp16); + // Find max + int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1); + int size = weight_fp32.numel(); + float max_val = FindMaxAbs(weight_fp32.data(), size); + std::vector max_vec(max_ptr_size, max_val); + weight_max->set_type(phi::DataType::FLOAT32); + weight_max->Resize({max_ptr_size}); + memcpy(cpu_ctx->Alloc(weight_max), + max_vec.data(), + max_ptr_size * sizeof(float)); + weight->clear(); + weight->set_type(phi::DataType::FLOAT16); + weight->Resize({size}); + memcpy(cpu_ctx->Alloc(weight), + weight_fp16.data(), + size * sizeof(float16)); +} + template void PrepareWeight(Graph* graph, Scope* scope, @@ -268,6 +326,18 @@ template void PrepareWeight( const std::vector& weight_scales, bool per_channel_quant = false); +template void PrepareWeight( + Graph* graph, + Scope* scope, + BlockDesc* block, + Node* weight, + Node** dst_weight, + Node** dst_weight_max, + Node** dst_scale_max, + bool transpose, + const std::vector& weight_scales, + bool per_channel_quant = false); + template void PrepareWeight( Graph* graph, Scope* scope, diff --git a/paddle/fluid/framework/ir/xpu/pass_utils.h b/paddle/fluid/framework/ir/xpu/pass_utils.h index 93d8f2860af9e..994188f1a28f8 100644 --- a/paddle/fluid/framework/ir/xpu/pass_utils.h +++ b/paddle/fluid/framework/ir/xpu/pass_utils.h @@ -57,6 +57,10 @@ std::vector FindOpNodeByInputName(Graph* graph, template size_t HashTensor(const phi::DenseTensor& in); +void ConvertFromFp32ToFp16(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose); + template ::value, Tcpu>::type* @@ -67,8 +71,12 @@ void ConvertWeightWrapper(phi::DenseTensor* weight, bool transpose, const std::vector& weight_scales, bool per_channel_quant) { - ConvertWithQuant( - weight, weight_max, scale_max, transpose, per_channel_quant); + if (std::is_same::value && std::is_same::value) { + ConvertFromFp32ToFp16(weight, weight_max, transpose); + } else { + ConvertWithQuant( + weight, weight_max, scale_max, transpose, per_channel_quant); + } } template Name() == out_var_name) { for (auto* var : out->outputs) { auto op_desc = var->Op(); - std::string quantized_op_type = op_desc->Type(); op_desc->SetAttr("enable_int8", true); op_desc->Flush(); } diff --git a/paddle/fluid/framework/ir/xpu/quant_utils.cc b/paddle/fluid/framework/ir/xpu/quant_utils.cc index 113e2ec0fe080..cdefbb5ca682c 100644 --- a/paddle/fluid/framework/ir/xpu/quant_utils.cc +++ b/paddle/fluid/framework/ir/xpu/quant_utils.cc @@ -115,41 +115,47 @@ void CastToInt32(phi::DenseTensor* in, phi::DenseTensor* out) { Assign(*out_ptr, in); } } - -void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out) { +void CastTo(phi::DenseTensor* in, phi::DenseTensor* out, DataType out_dtype) { auto* cpu_ctx = static_cast( platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); - 
paddle::experimental::CheckAndTrans2Contiguous(in); + if (in->dtype() != phi::DataType::FLOAT16 && + in->dtype() != phi::DataType::FLOAT32) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Only support fp16 and fp32, but received dtype is %s.", + phi::DataTypeToString(in->dtype()))); + } - phi::DenseTensor fp32_tensor; - phi::DenseTensor* out_ptr = out == nullptr ? &fp32_tensor : out; + paddle::experimental::CheckAndTrans2Contiguous(in); + phi::DenseTensor ori_tensor; + phi::DenseTensor* out_ptr = out == nullptr ? &ori_tensor : out; out_ptr->Resize(in->dims()); - out_ptr->set_type(phi::DataType::FLOAT32); + out_ptr->set_type(out_dtype); out_ptr->set_layout(in->layout()); - - switch (in->dtype()) { - case phi::DataType::FLOAT16: - phi::CastKernel( - *cpu_ctx, *in, phi::DataType::FLOAT32, out_ptr); - break; - case phi::DataType::FLOAT32: - if (out == nullptr) { - return; - } else { - phi::AssignKernel(*cpu_ctx, *in, out_ptr); - } - break; - default: - PADDLE_THROW(platform::errors::InvalidArgument( - "Only support fp16 and fp32, but received dtype is %s.", - phi::DataTypeToString(in->dtype()))); - break; + if (in->dtype() == out_dtype) { + if (out == nullptr) { + return; + } else { + phi::AssignKernel(*cpu_ctx, *in, out_ptr); + } + } else { + if (in->dtype() == phi::DataType::FLOAT16) { + phi::CastKernel(*cpu_ctx, *in, out_dtype, out_ptr); + } else { + phi::CastKernel(*cpu_ctx, *in, out_dtype, out_ptr); + } + if (out == nullptr) { + Assign(*out_ptr, in); + } } +} - if (out == nullptr) { - Assign(*out_ptr, in); - } +void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out) { + CastTo(in, out, phi::DataType::FLOAT32); +} + +void CastToFp16(phi::DenseTensor* in, phi::DenseTensor* out) { + CastTo(in, out, phi::DataType::FLOAT16); } static float FindMaxAbs(const float* data, int len) { diff --git a/paddle/fluid/framework/ir/xpu/quant_utils.h b/paddle/fluid/framework/ir/xpu/quant_utils.h index 6e6054acd0263..1366d69ec1f47 100644 --- a/paddle/fluid/framework/ir/xpu/quant_utils.h +++ b/paddle/fluid/framework/ir/xpu/quant_utils.h @@ -23,8 +23,12 @@ void Assign(const phi::DenseTensor& in, phi::DenseTensor* out); void Transpose2D(phi::DenseTensor* in, phi::DenseTensor* out = nullptr); +void CastTo(phi::DenseTensor* in, phi::DenseTensor* out, DataType dtype); + void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr); +void CastToFp16(phi::DenseTensor* in, phi::DenseTensor* out = nullptr); + void CastToInt32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr); template diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 47312ceb1d7c3..0684064df81e8 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -521,7 +521,7 @@ void CpuPassStrategy::EraseFcMkldnnPasses() { XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { passes_.assign({ - "quant_dequant_xpu_pass", + // "quant_dequant_xpu_pass", open this pass when use old int8 model "delete_quant_dequant_linear_op_pass", "delete_weight_dequant_linear_op_pass", "delete_assign_op_pass", diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 0dc6aad3d8c43..f8dcb02cbdc72 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -399,7 +399,7 @@ backward : max_pool2d_v2_grad - op : multi_encoder_xpu - args : (Tensor x, Tensor[] fc_input_max, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] 
ln_bias, Tensor mask, Tensor seq_lod, Tensor max_seq_len, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx, bool is_per_channel) + args : (Tensor x, Tensor[] fc_input_max, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] smooth_scale_weight, Tensor mask, Tensor seq_lod, Tensor max_seq_len, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx, bool is_per_channel, float[] softmax_max_value, str[] quant_types) output : Tensor(out), Tensor(x_fp16), Tensor(out_fp16) infer_meta : func : MultiEncoderXPUInferMeta diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 2275727b51019..e6e0082f626f0 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -1446,6 +1446,7 @@ void MultiEncoderXPUInferMeta( const std::vector& fc_bias, const std::vector& ln_scale, const std::vector& ln_bias, + const std::vector& smooth_scale_weight, const MetaTensor& mask, const MetaTensor& seq_lod, const MetaTensor& max_seq_len, @@ -1459,6 +1460,8 @@ void MultiEncoderXPUInferMeta( int relative_type, int slice_idx, bool is_per_channel, + const std::vector& softmax_max_value, + const std::vector& quant_types, MetaTensor* out, MetaTensor* x_fp16, MetaTensor* out_fp16) { diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 3fccc81f96134..f8e4cb82f6809 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -150,6 +150,7 @@ void MultiEncoderXPUInferMeta( const std::vector& fc_bias, const std::vector& ln_scale, const std::vector& ln_bias, + const std::vector& smooth_scale_weight, const MetaTensor& mask, const MetaTensor& seq_lod, const MetaTensor& max_seq_len, @@ -163,6 +164,8 @@ void MultiEncoderXPUInferMeta( int relative_type, int slice_idx, bool is_per_channel, + const std::vector& softmax_max_value, + const std::vector& quant_types, MetaTensor* out, MetaTensor* x_fp16, MetaTensor* out_fp16); diff --git a/paddle/phi/kernels/fusion/xpu/multi_encoder_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/multi_encoder_xpu_kernel.cc index 9cfbcf87f6c6b..1f76fc3ef02d8 100644 --- a/paddle/phi/kernels/fusion/xpu/multi_encoder_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/multi_encoder_xpu_kernel.cc @@ -25,7 +25,7 @@ namespace fusion { int r = xpu::transformer_encoder( \ ctx.x_context(), \ x_fp16_data, \ - fc_weight_data_##gemm_dtype_, \ + fc_weight_data_##w_dtype_, \ out_fp16_data, \ fc_input_max_data, \ fc_weight_max_data, \ @@ -37,30 +37,34 @@ namespace fusion { PADDLE_ENFORCE_XDNN_SUCCESS(r, "multi_encoder_xpu"); template -void MultiEncoderXPUKernel(const Context& ctx, - const DenseTensor& x, - const std::vector& fc_input_max, - const std::vector& fc_weight, - const std::vector& fc_weight_max, - const std::vector& fc_bias, - const std::vector& ln_scale, - const std::vector& ln_bias, - const paddle::optional& mask, - const paddle::optional& seq_lod, - const paddle::optional& max_seq_len, - int layer_num, - bool norm_before, - int hidden_dim, - int head_num, - int size_per_head, - int ffn_hidden_dim_scale, - int act_type, - int relative_type, - int slice_idx, - bool is_per_channel, - DenseTensor* out, - DenseTensor* x_fp16, - DenseTensor* out_fp16) { +void MultiEncoderXPUKernel( + const Context& ctx, + const DenseTensor& x, + const std::vector& fc_input_max, + 
const std::vector& fc_weight, + const std::vector& fc_weight_max, + const std::vector& fc_bias, + const std::vector& ln_scale, + const std::vector& ln_bias, + const std::vector& smooth_scale_weight, + const paddle::optional& mask, + const paddle::optional& seq_lod, + const paddle::optional& max_seq_len, + int layer_num, + bool norm_before, + int hidden_dim, + int head_num, + int size_per_head, + int ffn_hidden_dim_scale, + int act_type, + int relative_type, + int slice_idx, + bool is_per_channel, + const std::vector& softmax_max_value, + const std::vector& quant_types, + DenseTensor* out, + DenseTensor* x_fp16, + DenseTensor* out_fp16) { const int* seq_lod_data = seq_lod.get_ptr() == nullptr ? nullptr : seq_lod.get_ptr()->data(); const int* max_seq_len_data = max_seq_len.get_ptr() == nullptr @@ -103,42 +107,40 @@ void MultiEncoderXPUKernel(const Context& ctx, // q,k,v weight are fused. // Each encoder's weight should be: w0, null, null, w3, w4, w5 auto enable_int8 = fc_weight[0]->dtype() == phi::DataType::INT8; - std::vector quant_types(8 * layer_num, - xpu::QuantType::NOT_QUANT); + auto local_quant = fc_weight[0]->dtype() == phi::DataType::FLOAT16; + std::vector set_quant_types(8 * layer_num, + xpu::QuantType::NOT_QUANT); if (enable_int8) { - int quant_types_size = 0; - if (static_cast(fc_input_max.size()) == 18 * layer_num) { - // quant mul + matmul - quant_types_size = 8; - } else if (static_cast(fc_input_max.size()) == 12 * layer_num) { - // quant mul - quant_types_size = 6; - } else { - PADDLE_ENFORCE_XDNN_SUCCESS( - 1, "fc_input_max size must be 12 * layer_num or 18 * layer_num."); - } - for (int i = 0; i < layer_num; i++) { - for (int j = 0; j < quant_types_size; j++) { - quant_types[i * 8 + j] = xpu::QuantType::QUANT_INT8; - } + for (size_t i = 0; i < quant_types.size(); i++) { + set_quant_types[i] = xpu::QuantType::QUANT_INT8; } } std::vector fc_input_max_data; std::vector fc_weight_data_int16_t; std::vector fc_weight_data_int8_t; + std::vector fc_weight_data_XPUTypeFP16; std::vector fc_weight_max_data; std::vector fc_bias_data; for (size_t i = 0; i < fc_weight.size(); i++) { if (!enable_int8) { - fc_weight_data_int16_t.push_back(fc_weight[i]->data()); + if (local_quant) { + fc_weight_data_XPUTypeFP16.push_back( + reinterpret_cast(fc_weight[i]->data())); + } else { + fc_weight_data_int16_t.push_back( + reinterpret_cast(fc_weight[i]->data())); + } } else { - fc_weight_data_int8_t.push_back(fc_weight[i]->data()); + fc_weight_data_int8_t.push_back( + reinterpret_cast(fc_weight[i]->data())); } fc_weight_max_data.push_back(fc_weight_max[i]->data()); fc_bias_data.push_back(fc_bias[i]->data()); if (i % 4 == 0) { fc_weight_data_int16_t.push_back(nullptr); fc_weight_data_int16_t.push_back(nullptr); + fc_weight_data_XPUTypeFP16.push_back(nullptr); + fc_weight_data_XPUTypeFP16.push_back(nullptr); fc_weight_data_int8_t.push_back(nullptr); fc_weight_data_int8_t.push_back(nullptr); fc_weight_max_data.push_back(nullptr); @@ -180,10 +182,29 @@ void MultiEncoderXPUKernel(const Context& ctx, hidden_dim, norm_before, is_per_channel); - qkv_attn_param.quant_type_.assign(quant_types.begin(), quant_types.end()); + if (!softmax_max_value.empty()) { + qkv_attn_param.ptq_max_value = softmax_max_value; + } + if (!smooth_scale_weight.empty()) { + qkv_attn_param.is_smooth_quant = true; + std::vector smooth_scale_weight_ptr; + for (const auto& weight : smooth_scale_weight) { + auto tmp_ptr = reinterpret_cast( + weight->data()); + smooth_scale_weight_ptr.push_back(tmp_ptr); + } + 
qkv_attn_param.smooth_scale.assign(smooth_scale_weight_ptr.begin(), + smooth_scale_weight_ptr.end()); + } + qkv_attn_param.quant_type_.assign(set_quant_types.begin(), + set_quant_types.end()); qkv_attn_param.scale_of_hidden_units = ffn_hidden_dim_scale; if (!enable_int8) { - TRANSFORMER_ENCODER_KERNEL_IMPL(XPUTypeFP16, int16_t, int16_t) + if (local_quant) { + TRANSFORMER_ENCODER_KERNEL_IMPL(XPUTypeFP16, XPUTypeFP16, float) + } else { + TRANSFORMER_ENCODER_KERNEL_IMPL(XPUTypeFP16, int16_t, int16_t) + } } else { TRANSFORMER_ENCODER_KERNEL_IMPL(XPUTypeFP16, int8_t, int8_t) } @@ -204,10 +225,29 @@ void MultiEncoderXPUKernel(const Context& ctx, hidden_dim, norm_before, is_per_channel); - qkv_attn_param.quant_type_.assign(quant_types.begin(), quant_types.end()); + if (!softmax_max_value.empty()) { + qkv_attn_param.ptq_max_value = softmax_max_value; + } + if (!smooth_scale_weight.empty()) { + qkv_attn_param.is_smooth_quant = true; + std::vector smooth_scale_weight_ptr; + for (const auto& weight : smooth_scale_weight) { + auto tmp_ptr = reinterpret_cast( + weight->data()); + smooth_scale_weight_ptr.push_back(tmp_ptr); + } + qkv_attn_param.smooth_scale.assign(smooth_scale_weight_ptr.begin(), + smooth_scale_weight_ptr.end()); + } + qkv_attn_param.quant_type_.assign(set_quant_types.begin(), + set_quant_types.end()); qkv_attn_param.scale_of_hidden_units = ffn_hidden_dim_scale; if (!enable_int8) { - TRANSFORMER_ENCODER_KERNEL_IMPL(XPUTypeFP16, int16_t, int16_t) + if (local_quant) { + TRANSFORMER_ENCODER_KERNEL_IMPL(XPUTypeFP16, XPUTypeFP16, float) + } else { + TRANSFORMER_ENCODER_KERNEL_IMPL(XPUTypeFP16, int16_t, int16_t) + } } else { TRANSFORMER_ENCODER_KERNEL_IMPL(XPUTypeFP16, int8_t, int8_t) } @@ -231,10 +271,29 @@ void MultiEncoderXPUKernel(const Context& ctx, hidden_dim, norm_before, is_per_channel); - qkv_attn_param.quant_type_.assign(quant_types.begin(), quant_types.end()); + if (!softmax_max_value.empty()) { + qkv_attn_param.ptq_max_value = softmax_max_value; + } + if (!smooth_scale_weight.empty()) { + qkv_attn_param.is_smooth_quant = true; + std::vector smooth_scale_weight_ptr; + for (const auto& weight : smooth_scale_weight) { + auto tmp_ptr = reinterpret_cast( + weight->data()); + smooth_scale_weight_ptr.push_back(tmp_ptr); + } + qkv_attn_param.smooth_scale.assign(smooth_scale_weight_ptr.begin(), + smooth_scale_weight_ptr.end()); + } + qkv_attn_param.quant_type_.assign(set_quant_types.begin(), + set_quant_types.end()); qkv_attn_param.scale_of_hidden_units = ffn_hidden_dim_scale; if (!enable_int8) { - TRANSFORMER_ENCODER_KERNEL_IMPL(XPUTypeFP16, int16_t, int16_t) + if (local_quant) { + TRANSFORMER_ENCODER_KERNEL_IMPL(XPUTypeFP16, XPUTypeFP16, float) + } else { + TRANSFORMER_ENCODER_KERNEL_IMPL(XPUTypeFP16, int16_t, int16_t) + } } else { TRANSFORMER_ENCODER_KERNEL_IMPL(XPUTypeFP16, int8_t, int8_t) } @@ -260,6 +319,6 @@ PD_REGISTER_KERNEL(multi_encoder_xpu, phi::fusion::MultiEncoderXPUKernel, float, phi::dtype::float16) { - kernel->InputAt(8).SetBackend(phi::Backend::CPU); kernel->InputAt(9).SetBackend(phi::Backend::CPU); + kernel->InputAt(10).SetBackend(phi::Backend::CPU); }
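
Note (not part of the patch): the pass above introduces two environment-variable controls, `XPU_LOCAL_QUANT` (selects the fp16 "local quant" weight path instead of int16/int8 quantization) and `QUANT_GELU_OUT_THRESHOLD` (clamps the recorded gelu activation max used for int8 quantization, defaulting to 10.0 per the whitepaper cited in the diff). The following standalone sketch mirrors that logic for reference only; the helper names and the `main()` driver are illustrative, not part of the Paddle code.

```cpp
// Illustrative sketch of the env-var handling added by this pass.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <stdexcept>

// Mirrors the XPU_LOCAL_QUANT check in ApplySingleEncoderXPUFuse:
// any non-zero integer value enables the fp16 local-quant weight path.
bool LocalQuantEnabled() {
  const char* env = std::getenv("XPU_LOCAL_QUANT");
  return env != nullptr && std::atoi(env) != 0;
}

// Mirrors the gelu clamp in PrepareInputMax: the recorded activation max for
// the gelu input/output is capped at QUANT_GELU_OUT_THRESHOLD (default 10.0,
// which must be a positive value).
float ClampGeluOutMax(float recorded_max) {
  float threshold = 10.f;
  if (const char* env = std::getenv("QUANT_GELU_OUT_THRESHOLD")) {
    threshold = static_cast<float>(std::atof(env));
    if (threshold <= 0.f) {
      throw std::invalid_argument(
          "QUANT_GELU_OUT_THRESHOLD should be a positive float value");
    }
  }
  return std::min(threshold, recorded_max);
}

int main() {
  std::cout << "local quant enabled: " << LocalQuantEnabled() << "\n";
  std::cout << "clamped gelu max: " << ClampGeluOutMax(37.5f) << "\n";
  return 0;
}
```

For example, running inference as `XPU_LOCAL_QUANT=1 QUANT_GELU_OUT_THRESHOLD=8.0 ./your_inference_binary` (binary name illustrative) would take the fp16 weight path and tighten the gelu clamp below the default.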