Commit bd67209

[XPU][PHI Kernels] add int_with_ll quantization for conv kernels (#54827)
* add int_with_ll to conv
* fix bugs when output_size is specified for conv2d_transpose
1 parent 9c2dae1 commit bd67209

3 files changed: +153 -8 lines changed
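
Context for the diffs below: each kernel selects its XDNN compute type from FCCalcType<XPUT>(), and this commit replaces the bare literals 1 and 2 with named enumerators while adding a new int_with_ll branch. As a minimal sketch, the enum presumably looks like the following; the authoritative definition lives in Paddle's XPU backend headers, and the names and ordering here are inferred from the comparisons in the diff:

// Sketch only; inferred from the diff, not copied from Paddle's headers.
enum XPUFCCalcType {
  FC_INT16 = 0,      // default path: the trailing else branch (int16_t)
  FC_INT32,          // previously compared as the literal 1
  FC_FLOAT,          // previously compared as the literal 2
  FC_INT32_WITH_LL,  // new in this commit: maps to int_with_ll_t in XDNN calls
};

Named enumerators make the four-way dispatch self-documenting and keep the branch order consistent across the three files.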

paddle/phi/kernels/xpu/conv_grad_kernel.cc

+55 -4
@@ -107,7 +107,7 @@ void ConvGradKernel(const Context& dev_ctx,
     }
   }
   int fccal_type = FCCalcType<XPUT>();
-  if (fccal_type == 1) {
+  if (fccal_type == XPUFCCalcType::FC_INT32) {
     int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int>(dev_ctx.x_context(),
                                                     input_data,
                                                     filter_data_ptr,
@@ -132,7 +132,7 @@ void ConvGradKernel(const Context& dev_ctx,
                                                     is_nchw);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad");
 
-  } else if (fccal_type == 2) {
+  } else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
     int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, float>(dev_ctx.x_context(),
                                                       input_data,
                                                       filter_data_ptr,
@@ -157,6 +157,31 @@ void ConvGradKernel(const Context& dev_ctx,
                                                       is_nchw);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad");
 
+  } else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
+    int r =
+        xpu::conv2d_grad<XPUT, XPUT, XPUT, int_with_ll_t>(dev_ctx.x_context(),
+                                                          input_data,
+                                                          filter_data_ptr,
+                                                          output_grad_data,
+                                                          input_grad_data,
+                                                          filter_grad_data_ptr,
+                                                          batch_size,
+                                                          img_c,
+                                                          img_h,
+                                                          img_w,
+                                                          f,
+                                                          ksize,
+                                                          strides,
+                                                          paddings,
+                                                          dilations,
+                                                          groups,
+                                                          nullptr,
+                                                          nullptr,
+                                                          nullptr,
+                                                          nullptr,
+                                                          nullptr,
+                                                          is_nchw);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad");
   } else {
     int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
                                                         input_data,
@@ -305,7 +330,7 @@ void Conv3DGradKernel(const Context& dev_ctx,
     }
   }
   int fccal_type = FCCalcType<XPUT>();
-  if (fccal_type == 1) {
+  if (fccal_type == XPUFCCalcType::FC_INT32) {
     int r = xpu::conv3d_grad<XPUT, XPUT, XPUT, int>(dev_ctx.x_context(),
                                                     input_data,
                                                     filter_data_ptr,
@@ -330,7 +355,7 @@ void Conv3DGradKernel(const Context& dev_ctx,
                                                     nullptr,
                                                     is_ncdhw);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad");
-  } else if (fccal_type == 2) {
+  } else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
     int r = xpu::conv3d_grad<XPUT, XPUT, XPUT, float>(dev_ctx.x_context(),
                                                       input_data,
                                                       filter_data_ptr,
@@ -355,6 +380,32 @@ void Conv3DGradKernel(const Context& dev_ctx,
                                                       nullptr,
                                                       is_ncdhw);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad");
+  } else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
+    int r =
+        xpu::conv3d_grad<XPUT, XPUT, XPUT, int_with_ll_t>(dev_ctx.x_context(),
+                                                          input_data,
+                                                          filter_data_ptr,
+                                                          output_grad_data,
+                                                          input_grad_data,
+                                                          filter_grad_data_ptr,
+                                                          batch_size,
+                                                          img_c,
+                                                          img_d,
+                                                          img_h,
+                                                          img_w,
+                                                          f,
+                                                          ksize,
+                                                          strides,
+                                                          paddings,
+                                                          dilations,
+                                                          groups,
+                                                          nullptr,
+                                                          nullptr,
+                                                          nullptr,
+                                                          nullptr,
+                                                          nullptr,
+                                                          is_ncdhw);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad");
   } else {
     int r = xpu::conv3d_grad<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
                                                         input_data,
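
In every call above, the first three template arguments of xpu::conv2d_grad / xpu::conv3d_grad are the input, filter, and output element types (all XPUT here), and the fourth is the on-device GEMM/quantization type, which this commit extends with int_with_ll_t. An abbreviated sketch of that call shape follows; this is not the real xpu.h declaration, and the trailing parameters are elided:

// Hypothetical, abbreviated declaration showing where int_with_ll_t plugs in;
// the argument order mirrors the calls in the diff above.
namespace xpu {
struct Context;  // opaque XDNN device context
template <typename TX, typename TW, typename TY, typename TGEMM>
int conv2d_grad(Context* ctx,
                const TX* input,
                const TW* filter,
                const TY* output_grad,
                TX* input_grad,
                TW* filter_grad
                /* , batch/shape ints, ksize, strides, paddings, dilations,
                   groups, max/scale pointers (nullptr above), layout flag */);
}  // namespace xpu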

paddle/phi/kernels/xpu/conv_kernel.cc

+45 -4
@@ -89,7 +89,7 @@ void ConvKernel(const Context& dev_ctx,
   }
 
   int fccal_type = FCCalcType<XPUT>();
-  if (fccal_type == 1) {
+  if (fccal_type == XPUFCCalcType::FC_INT32) {
     int r = xpu::conv2d<XPUT, XPUT, XPUT, int>(dev_ctx.x_context(),
                                                input_data,
                                                filter_data_ptr,
@@ -109,7 +109,7 @@ void ConvKernel(const Context& dev_ctx,
                                                nullptr,
                                                is_nchw);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d");
-  } else if (fccal_type == 2) {
+  } else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
     int r = xpu::conv2d<XPUT, XPUT, XPUT, float>(dev_ctx.x_context(),
                                                  input_data,
                                                  filter_data_ptr,
@@ -129,6 +129,26 @@ void ConvKernel(const Context& dev_ctx,
                                                  nullptr,
                                                  is_nchw);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d");
+  } else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
+    int r = xpu::conv2d<XPUT, XPUT, XPUT, int_with_ll_t>(dev_ctx.x_context(),
+                                                         input_data,
+                                                         filter_data_ptr,
+                                                         output_data,
+                                                         batch_size,
+                                                         img_c,
+                                                         img_h,
+                                                         img_w,
+                                                         f,
+                                                         ksize,
+                                                         strides,
+                                                         paddings,
+                                                         dilations,
+                                                         groups,
+                                                         nullptr,
+                                                         nullptr,
+                                                         nullptr,
+                                                         is_nchw);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d");
   } else {
     int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
                                                    input_data,
@@ -239,7 +259,7 @@ void Conv3DKernel(const Context& dev_ctx,
   }
 
   int fccal_type = FCCalcType<XPUT>();
-  if (fccal_type == 1) {
+  if (fccal_type == XPUFCCalcType::FC_INT32) {
     int r = xpu::conv3d<XPUT, XPUT, XPUT, int>(dev_ctx.x_context(),
                                                input_data,
                                                filter_data_ptr,
@@ -260,7 +280,7 @@ void Conv3DKernel(const Context& dev_ctx,
                                                nullptr,
                                                is_ncdhw);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d");
-  } else if (fccal_type == 2) {
+  } else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
     int r = xpu::conv3d<XPUT, XPUT, XPUT, float>(dev_ctx.x_context(),
                                                  input_data,
                                                  filter_data_ptr,
@@ -282,6 +302,27 @@ void Conv3DKernel(const Context& dev_ctx,
                                                  is_ncdhw);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d");
 
+  } else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
+    int r = xpu::conv3d<XPUT, XPUT, XPUT, int_with_ll_t>(dev_ctx.x_context(),
+                                                         input_data,
+                                                         filter_data_ptr,
+                                                         output_data,
+                                                         batch_size,
+                                                         img_c,
+                                                         img_d,
+                                                         img_h,
+                                                         img_w,
+                                                         f,
+                                                         ksize,
+                                                         strides,
+                                                         paddings,
+                                                         dilations,
+                                                         groups,
+                                                         nullptr,
+                                                         nullptr,
+                                                         nullptr,
+                                                         is_ncdhw);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d");
   } else {
     int r = xpu::conv3d<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
                                                    input_data,
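
The same four-way if/else ladder now appears in ConvKernel, Conv3DKernel, and both grad kernels. A hypothetical helper, not part of this commit, could centralize the dispatch; Tag and DispatchByCalcType are illustrative names:

// Illustrative only: collapse the repeated branches into one dispatcher.
template <typename T>
struct Tag {
  using type = T;
};

template <typename XPUT, typename F>
int DispatchByCalcType(F&& run) {
  int fccal_type = FCCalcType<XPUT>();
  if (fccal_type == XPUFCCalcType::FC_INT32) return run(Tag<int>{});
  if (fccal_type == XPUFCCalcType::FC_FLOAT) return run(Tag<float>{});
  if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
    return run(Tag<int_with_ll_t>{});
  }
  return run(Tag<int16_t>{});  // default: int16 quantization
}

// Usage sketch inside a kernel body:
//   int r = DispatchByCalcType<XPUT>([&](auto tag) {
//     using TGEMM = typename decltype(tag)::type;
//     return xpu::conv2d<XPUT, XPUT, XPUT, TGEMM>(dev_ctx.x_context(), ...);
//   });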

paddle/phi/kernels/xpu/conv_transpose_kernel.cc

+53 -0
@@ -14,6 +14,8 @@
 
 #include "paddle/phi/kernels/conv_transpose_kernel.h"
 
+#include "glog/logging.h"
+
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/cpu/conv_util.h"
@@ -122,6 +124,57 @@ void Conv2dTransposeKernel(const Context& ctx,
         nullptr,
         true);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_v2");
+  } else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
+    if (output_size.size()) {
+      VLOG(4) << "int_with_ll quantization is not supported when output_size "
+                 "is specified, "
+              << "use int31 instead";
+      int r = xpu::conv2d_transpose_v2<float, float, float, int32_t>(
+          ctx.x_context(),
+          x.data<float>(),
+          filter_.data<float>(),
+          out->data<float>(),
+          batch_size,
+          img_yc,
+          img_xh,
+          img_xw,
+          img_xc,
+          ksize,
+          strides,
+          paddings_,
+          dilations_,
+          groups,
+          nullptr,
+          nullptr,
+          nullptr,
+          true);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_v2");
+    } else {
+      // xpu::conv2d_transpose_v2 does not support int_with_ll yet,
+      // so use xpu::conv2d_transpose instead
+      int img_yh = static_cast<int>(x.dims()[2]);
+      int img_yw = static_cast<int>(x.dims()[3]);
+      int r = xpu::conv2d_transpose<float, float, float, int_with_ll_t>(
+          ctx.x_context(),
+          x.data<float>(),
+          filter_.data<float>(),
+          out->data<float>(),
+          batch_size,
+          img_yc,
+          img_yh,
+          img_yw,
+          img_xc,
+          ksize,
+          strides,
+          paddings_,
+          dilations_,
+          groups,
+          nullptr,
+          nullptr,
+          nullptr,
+          true);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose");
+    }
   } else {
     int r = xpu::conv2d_transpose_v2<XPUT, XPUT, XPUT, int16_t>(
         ctx.x_context(),
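
This hunk carries the second fix from the commit message. conv2d_transpose_v2 is the entry point that honors a user-specified output_size, but it has no int_with_ll path, so when output_size is given the kernel logs a VLOG(4) notice and downgrades to int31/int32; otherwise it routes to the older conv2d_transpose, which does accept int_with_ll_t but derives the output shape from the input dims itself. Condensed, with hypothetical helper names standing in for the two fully-spelled-out xpu calls above:

// Illustrative condensation; RunV2 and RunLegacy are hypothetical stand-ins.
if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
  if (!output_size.empty()) {
    RunV2</*TGEMM=*/int32_t>();  // honors output_size; no int_with_ll support
  } else {
    // output shape comes from x.dims()[2] / x.dims()[3] instead
    RunLegacy</*TGEMM=*/int_with_ll_t>();
  }
}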
