Skip to content

Commit 6f20460

Browse files
[arm-cpu] fix conv+hardswish in int8-int8 compute diff (#8996) (#9001)
1 parent a280a0a commit 6f20460

File tree

5 files changed

+30
-2
lines changed

5 files changed

+30
-2
lines changed

lite/backends/arm/math/gemv_arm_int8.cc

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,8 @@ inline void write_gemv_out(const int* in,
12221222
float32x4_t valpha = vdupq_n_f32(alpha);
12231223
float32x4_t voffset = vdupq_n_f32(offset);
12241224
float32x4_t vthreshold = vdupq_n_f32(threshold);
1225+
float32x4_t vmax = vdupq_n_f32(-127.f);
1226+
12251227
#ifdef __aarch64__
12261228
asm volatile(
12271229
"cmp %w[cnt], #1\n"
@@ -1252,6 +1254,11 @@ inline void write_gemv_out(const int* in,
12521254
"fmul v0.4s, v4.4s, v5.4s\n"
12531255
"fmin v6.4s, v6.4s, %[vthreshold].4s\n"
12541256
"fmul v3.4s, v6.4s, v7.4s\n"
1257+
// out >= -127
1258+
"fcmge v4.4s, v0.4s, %[vmax].4s\n"
1259+
"fcmge v5.4s, v3.4s, %[vmax].4s\n"
1260+
"bif v0.16b, %[vmax].16b, v4.16b\n"
1261+
"bif v3.16b, %[vmax].16b, v5.16b\n"
12551262
// fp32 - int32
12561263
"fcvtas v4.4s, v0.4s\n"
12571264
"fcvtas v5.4s, v3.4s\n"
@@ -1279,6 +1286,9 @@ inline void write_gemv_out(const int* in,
12791286
"fmax v4.4s, v4.4s, %[vzero].4s\n"
12801287
"fmin v4.4s, v4.4s, %[vthreshold].4s\n"
12811288
"fmul v0.4s, v4.4s, v5.4s\n"
1289+
// out >= -127
1290+
"fcmge v4.4s, v0.4s, %[vmax].4s\n"
1291+
"bif v0.16b, %[vmax].16b, v4.16b\n"
12821292
// fp32 - int32
12831293
"fcvtas v4.4s, v0.4s\n"
12841294
// int32 - int16
@@ -1298,7 +1308,8 @@ inline void write_gemv_out(const int* in,
12981308
[vzero] "w"(vzero),
12991309
[valpha] "w"(valpha),
13001310
[voffset] "w"(voffset),
1301-
[vthreshold] "w"(vthreshold)
1311+
[vthreshold] "w"(vthreshold),
1312+
[vmax] "w"(vmax)
13021313
: "cc",
13031314
"memory",
13041315
"v0",
@@ -1349,6 +1360,11 @@ inline void write_gemv_out(const int* in,
13491360
"vbif q13, %q[vfive], q10\n"
13501361
"vadd.f32 q5, q5, q12\n"
13511362
"vadd.f32 q8, q8, q13\n"
1363+
// data >= -127
1364+
"vcge.f32 q7, q5, %q[vmax]\n"
1365+
"vcge.f32 q9, q8, %q[vmax]\n"
1366+
"vbif q5, %q[vmax], q7\n"
1367+
"vbif q8, %q[vmax], q9\n"
13521368
// fp32 -> int32
13531369
"vcvt.s32.f32 q7, q5\n"
13541370
"vcvt.s32.f32 q9, q8\n"
@@ -1380,6 +1396,9 @@ inline void write_gemv_out(const int* in,
13801396
"vcge.f32 q7, q5, %q[vzero]\n"
13811397
"vbif q12, %q[vfive], q7\n"
13821398
"vadd.f32 q5, q5, q12\n"
1399+
// data >= -127
1400+
"vcge.f32 q7, q5, %q[vmax]\n"
1401+
"vbif q5, %q[vmax], q7\n"
13831402
// fp32 -> int32
13841403
"vcvt.s32.f32 q7, q5\n"
13851404
// int32 -> int16
@@ -1400,7 +1419,8 @@ inline void write_gemv_out(const int* in,
14001419
[valpha] "w"(valpha),
14011420
[voffset] "w"(voffset),
14021421
[vthreshold] "w"(vthreshold),
1403-
[vfive] "w"(vfive)
1422+
[vfive] "w"(vfive),
1423+
[vmax] "w"(vmax)
14041424
: "cc",
14051425
"memory",
14061426
"q4",
@@ -1624,6 +1644,7 @@ bool gemv_int8_trans_oth(const int8_t* A,
16241644
memset(zerobuf, 0, sizeof(float) * (M + 16));
16251645
const float* bias_ptr = is_bias ? bias : zerobuf;
16261646
float six = alpha;
1647+
16271648
#ifdef __aarch64__
16281649
int cnt = N >> 3;
16291650
int tail = N & 7;

lite/kernels/arm/conv_direct.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ inline bool direct_conv_trans_weights<PRECISION(kInt8), PRECISION(kInt8)>(
154154
}
155155
//! update hardswish parameter
156156
if (act_param.active_type == lite_api::ActivationType::kHardSwish) {
157+
act_param.hard_swish_scale = act_param.hard_swish_scale / out_scale;
157158
act_param.hard_swish_offset = act_param.hard_swish_offset / out_scale;
158159
act_param.hard_swish_threshold = act_param.hard_swish_threshold / out_scale;
159160
}

lite/kernels/arm/conv_gemmlike.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ void GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
8686
//! update hardswish parameter
8787
if (param.activation_param.active_type ==
8888
lite_api::ActivationType::kHardSwish) {
89+
param.activation_param.hard_swish_scale =
90+
param.activation_param.hard_swish_scale / param.output_scale;
8991
param.activation_param.hard_swish_offset =
9092
param.activation_param.hard_swish_offset / param.output_scale;
9193
param.activation_param.hard_swish_threshold =

lite/kernels/arm/conv_transpose_compute.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ void Conv2DTransposeCompute<PRECISION(kInt8),
154154
//! update hardswish parameter
155155
if (param.activation_param.active_type ==
156156
lite_api::ActivationType::kHardSwish) {
157+
param.activation_param.hard_swish_scale =
158+
param.activation_param.hard_swish_scale / param.output_scale;
157159
param.activation_param.hard_swish_offset =
158160
param.activation_param.hard_swish_offset / param.output_scale;
159161
param.activation_param.hard_swish_threshold =

lite/kernels/arm/conv_winograd.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ void WinogradConv<PRECISION(kInt8), OutType>::ReInitWhenNeeded() {
196196
//! update hardswish parameter
197197
if (param.activation_param.active_type ==
198198
lite_api::ActivationType::kHardSwish) {
199+
param.activation_param.hard_swish_scale =
200+
param.activation_param.hard_swish_scale / param.output_scale;
199201
param.activation_param.hard_swish_offset =
200202
param.activation_param.hard_swish_offset / output_scale;
201203
param.activation_param.hard_swish_threshold =

0 commit comments

Comments
 (0)