From e184eb00c059f68ea714e6c30e66544b26cbacc0 Mon Sep 17 00:00:00 2001
From: HappyAngel
Date: Mon, 9 May 2022 19:13:55 +0800
Subject: [PATCH] [arm-cpu] fix conv+hardswish in int8-int8 compute diff (#8996)

---
 lite/backends/arm/math/gemv_arm_int8.cc    | 25 ++++++++++++++++++++--
 lite/kernels/arm/conv_direct.h             |  1 +
 lite/kernels/arm/conv_gemmlike.cc          |  2 ++
 lite/kernels/arm/conv_transpose_compute.cc |  2 ++
 lite/kernels/arm/conv_winograd.cc          |  2 ++
 5 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/lite/backends/arm/math/gemv_arm_int8.cc b/lite/backends/arm/math/gemv_arm_int8.cc
index 8f7e0b0474b..34be515f477 100644
--- a/lite/backends/arm/math/gemv_arm_int8.cc
+++ b/lite/backends/arm/math/gemv_arm_int8.cc
@@ -1222,6 +1222,8 @@ inline void write_gemv_out(const int* in,
   float32x4_t valpha = vdupq_n_f32(alpha);
   float32x4_t voffset = vdupq_n_f32(offset);
   float32x4_t vthreshold = vdupq_n_f32(threshold);
+  float32x4_t vmax = vdupq_n_f32(-127.f);
+
 #ifdef __aarch64__
   asm volatile(
       "cmp %w[cnt], #1\n"
@@ -1252,6 +1254,11 @@
       "fmul v0.4s, v4.4s, v5.4s\n"
       "fmin v6.4s, v6.4s, %[vthreshold].4s\n"
       "fmul v3.4s, v6.4s, v7.4s\n"
+      // out >= -127
+      "fcmge v4.4s, v0.4s, %[vmax].4s\n"
+      "fcmge v5.4s, v3.4s, %[vmax].4s\n"
+      "bif v0.16b, %[vmax].16b, v4.16b\n"
+      "bif v3.16b, %[vmax].16b, v5.16b\n"
       // fp32 - int32
       "fcvtas v4.4s, v0.4s\n"
       "fcvtas v5.4s, v3.4s\n"
@@ -1279,6 +1286,9 @@
       "fmax v4.4s, v4.4s, %[vzero].4s\n"
       "fmin v4.4s, v4.4s, %[vthreshold].4s\n"
       "fmul v0.4s, v4.4s, v5.4s\n"
+      // out >= -127
+      "fcmge v4.4s, v0.4s, %[vmax].4s\n"
+      "bif v0.16b, %[vmax].16b, v4.16b\n"
       // fp32 - int32
       "fcvtas v4.4s, v0.4s\n"
       // int32 - int16
@@ -1298,7 +1308,8 @@
        [vzero] "w"(vzero),
        [valpha] "w"(valpha),
        [voffset] "w"(voffset),
-       [vthreshold] "w"(vthreshold)
+       [vthreshold] "w"(vthreshold),
+       [vmax] "w"(vmax)
      : "cc",
        "memory",
        "v0",
@@ -1349,6 +1360,11 @@
       "vbif q13, %q[vfive], q10\n"
       "vadd.f32 q5, q5, q12\n"
       "vadd.f32 q8, q8, q13\n"
+      // data >= -127
+      "vcge.f32 q7, q5, %q[vmax]\n"
+      "vcge.f32 q9, q8, %q[vmax]\n"
+      "vbif q5, %q[vmax], q7\n"
+      "vbif q8, %q[vmax], q9\n"
       // fp32 -> int32
       "vcvt.s32.f32 q7, q5\n"
       "vcvt.s32.f32 q9, q8\n"
@@ -1380,6 +1396,9 @@
       "vcge.f32 q7, q5, %q[vzero]\n"
       "vbif q12, %q[vfive], q7\n"
       "vadd.f32 q5, q5, q12\n"
+      // data >= -127
+      "vcge.f32 q7, q5, %q[vmax]\n"
+      "vbif q5, %q[vmax], q7\n"
       // fp32 -> int32
       "vcvt.s32.f32 q7, q5\n"
       // int32 -> int16
@@ -1400,7 +1419,8 @@
        [valpha] "w"(valpha),
        [voffset] "w"(voffset),
        [vthreshold] "w"(vthreshold),
-       [vfive] "w"(vfive)
+       [vfive] "w"(vfive),
+       [vmax] "w"(vmax)
      : "cc",
        "memory",
        "q4",
@@ -1624,6 +1644,7 @@ bool gemv_int8_trans_oth(const int8_t* A,
   memset(zerobuf, 0, sizeof(float) * (M + 16));
   const float* bias_ptr = is_bias ? bias : zerobuf;
   float six = alpha;
+
 #ifdef __aarch64__
   int cnt = N >> 3;
   int tail = N & 7;
diff --git a/lite/kernels/arm/conv_direct.h b/lite/kernels/arm/conv_direct.h
index 6276ebc0f3e..169f0c15496 100644
--- a/lite/kernels/arm/conv_direct.h
+++ b/lite/kernels/arm/conv_direct.h
@@ -154,6 +154,7 @@ inline bool direct_conv_trans_weights(
   }
   //! update hardswish parameter
   if (act_param.active_type == lite_api::ActivationType::kHardSwish) {
+    act_param.hard_swish_scale = act_param.hard_swish_scale / out_scale;
     act_param.hard_swish_offset = act_param.hard_swish_offset / out_scale;
     act_param.hard_swish_threshold = act_param.hard_swish_threshold / out_scale;
   }
diff --git a/lite/kernels/arm/conv_gemmlike.cc b/lite/kernels/arm/conv_gemmlike.cc
index b986f96ad3a..bcc89fe3a45 100644
--- a/lite/kernels/arm/conv_gemmlike.cc
+++ b/lite/kernels/arm/conv_gemmlike.cc
@@ -86,6 +86,8 @@ void GemmLikeConv::PrepareForRun() {
   //! update hardswish parameter
   if (param.activation_param.active_type ==
       lite_api::ActivationType::kHardSwish) {
+    param.activation_param.hard_swish_scale =
+        param.activation_param.hard_swish_scale / param.output_scale;
     param.activation_param.hard_swish_offset =
         param.activation_param.hard_swish_offset / param.output_scale;
     param.activation_param.hard_swish_threshold =
diff --git a/lite/kernels/arm/conv_transpose_compute.cc b/lite/kernels/arm/conv_transpose_compute.cc
index 23aaeca5b35..b81978116f7 100644
--- a/lite/kernels/arm/conv_transpose_compute.cc
+++ b/lite/kernels/arm/conv_transpose_compute.cc
@@ -154,6 +154,8 @@ void Conv2DTransposeCompute::ReInitWhenNeeded() {
   //! update hardswish parameter
   if (param.activation_param.active_type ==
       lite_api::ActivationType::kHardSwish) {
+    param.activation_param.hard_swish_scale =
+        param.activation_param.hard_swish_scale / param.output_scale;
     param.activation_param.hard_swish_offset =
         param.activation_param.hard_swish_offset / output_scale;
     param.activation_param.hard_swish_threshold =
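
For context, a minimal scalar sketch of the int8 write-out that the NEON changes above implement. This is not code from the patch: the helper name write_out_int8_hardswish_ref and its signature are hypothetical, and it assumes Paddle Lite's hard_swish form out = x * min(max(x + offset, 0), threshold) / scale with all three parameters pre-divided by the output scale, as the kernel-side changes in this patch arrange.

// Scalar sketch (assumption, not from the patch) of the int8 hard-swish
// write-out that the added fcmge/bif (aarch64) and vcge/vbif (armv7)
// instructions implement in NEON.
#include <algorithm>
#include <cmath>
#include <cstdint>

// `v` is the accumulator already scaled into output units; `scale`, `offset`
// and `threshold` are the hard_swish parameters already divided by the output
// scale (the PrepareForRun/ReInitWhenNeeded changes above).
inline int8_t write_out_int8_hardswish_ref(float v,
                                           float scale,
                                           float offset,
                                           float threshold) {
  // hard_swish: x * min(max(x + offset, 0), threshold) / scale
  float y = v * std::min(std::max(v + offset, 0.f), threshold) / scale;
  // Lower clamp to -127 so the result stays in the symmetric int8 range;
  // this is the clamp the patch adds via vmax = vdupq_n_f32(-127.f).
  y = std::max(y, -127.f);
  // Round to nearest (like fcvtas), then saturate the upper bound.
  int32_t q = static_cast<int32_t>(std::lround(y));
  return static_cast<int8_t>(std::min<int32_t>(q, 127));
}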