Commit b6583ca

add quant_dequant_moving_avg_max_abs op
test=develop
1 parent 75cda4d commit b6583ca

File tree: 4 files changed, +163 -45 lines changed

paddle/fluid/operators/fake_quantize_op.cc

Lines changed: 26 additions & 0 deletions
@@ -68,6 +68,23 @@ struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
 
 template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;
 
+template <typename T>
+struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& ctx,
+                  const framework::Tensor& in, const framework::Tensor& scale,
+                  const int bin_cnt, framework::Tensor* out) {
+    T s = scale.data<T>()[0];
+    platform::Transform<platform::CPUDeviceContext> trans;
+    trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
+          out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
+    auto out_e = framework::EigenVector<T>::Flatten(*out);
+    out_e.device(*ctx.eigen_device()) =
+        (s / bin_cnt) * (bin_cnt / s * out_e).round();
+  }
+};
+template struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext,
+                                               float>;
+
 template <typename T>
 struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& ctx,

@@ -480,8 +497,17 @@ REGISTER_OPERATOR(fake_quantize_moving_average_abs_max,
                   ops::FakeQuantizeMovingAverageAbsMaxOp,
                   ops::FakeQuantizeMovingAverageAbsMaxOpMaker,
                   paddle::framework::EmptyGradOpMaker);
+
 REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
                        ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);
+REGISTER_OPERATOR(fake_quantize_dequantize_moving_average_abs_max,
+                  ops::FakeQuantizeMovingAverageAbsMaxOp,
+                  ops::FakeQuantizeMovingAverageAbsMaxOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    fake_quantize_dequantize_moving_average_abs_max,
+    ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>);
+
 REGISTER_OPERATOR(fake_channel_wise_quantize_abs_max,
                   ops::FakeChannelWiseQuantizeAbsMaxOp,
                   ops::FakeChannelWiseQuantizeAbsMaxOpMaker,
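Note: the new functor performs a quantize-dequantize round trip: clip the input to [-s, s], map it onto bin_cnt integer levels, round, and immediately map it back to the input's scale, so the output carries quantization error while staying a float tensor. A minimal NumPy sketch of the same math (the helper name fake_quant_dequant is illustrative, not part of this patch):

    import numpy as np

    def fake_quant_dequant(x, scale, bit_length):
        # bin_cnt mirrors the op's bin count: 2^(bit_length - 1) - 1
        bin_cnt = (1 << (bit_length - 1)) - 1
        clipped = np.clip(x, -scale, scale)
        # quantize to integer levels, then dequantize back to the input scale
        return np.round(clipped * bin_cnt / scale) * scale / bin_cnt

    x = np.random.uniform(-1, 1, (8, 16)).astype("float32")
    out = fake_quant_dequant(x, np.abs(x).max(), bit_length=5)

Unlike fake_quantize_moving_average_abs_max, whose Out holds the rounded integer levels, the dequantized output stays on the original scale of X.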

paddle/fluid/operators/fake_quantize_op.cu

Lines changed: 41 additions & 0 deletions
@@ -129,6 +129,23 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
   }
 }
 
+template <typename T>
+__global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
+                                          const int bin_cnt, const int n,
+                                          T* out) {
+  int bid = threadIdx.x + blockIdx.x * blockDim.x;
+  int tid = threadIdx.x;
+
+  T s = scale[0];
+  for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
+    T x = in[i];
+    T v = x > s ? s : x;
+    v = v < -s ? -s : v;
+    v = bin_cnt / s * v;
+    out[i] = round(v) * s / bin_cnt;
+  }
+}
+
 template <typename T>
 struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& ctx,

@@ -149,6 +166,27 @@ struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
 
 template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;
 
+template <typename T>
+struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& ctx,
+                  const framework::Tensor& in, const framework::Tensor& scale,
+                  const int bin_cnt, framework::Tensor* out) {
+    int num = in.numel();
+    int block = 1024;
+    int grid = (block - 1 + num) / block;
+
+    const T* in_data = in.data<T>();
+    const T* scale_data = scale.data<T>();
+    T* out_data = out->mutable_data<T>(ctx.GetPlace());
+
+    ClipAndQuantDequantKernel<T><<<grid, block, 0, ctx.stream()>>>(
+        in_data, scale_data, bin_cnt, num, out_data);
+  }
+};
+
+template struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext,
+                                               float>;
+
 template <typename T>
 __global__ void ChannelClipAndQuantKernel(const T* in, const T* scale,
                                           const int bin_cnt, const int n,

@@ -302,3 +340,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>);
 REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale,
                         ops::MovingAverageAbsMaxScaleKernel<CUDA, float>);
+REGISTER_OP_CUDA_KERNEL(
+    fake_quantize_dequantize_moving_average_abs_max,
+    ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float>);
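Note: the CUDA functor launches ClipAndQuantDequantKernel as a grid-stride loop; grid = (block - 1 + num) / block is integer ceiling division, so every element is covered even when num is not a multiple of 1024. A quick check of that launch arithmetic in Python (names are illustrative):

    def grid_size(num, block=1024):
        # ceil(num / block) without floating point
        return (block - 1 + num) // block

    assert grid_size(1) == 1
    assert grid_size(1024) == 1
    assert grid_size(1025) == 2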

paddle/fluid/operators/fake_quantize_op.h

Lines changed: 39 additions & 5 deletions
@@ -35,6 +35,13 @@ struct ClipAndFakeQuantFunctor {
                   framework::Tensor* out);
 };
 
+template <typename DeviceContext, typename T>
+struct ClipAndFakeQuantDequantFunctor {
+  void operator()(const DeviceContext& ctx, const framework::Tensor& in,
+                  const framework::Tensor& scale, const int bin_cnt,
+                  framework::Tensor* out);
+};
+
 template <typename DeviceContext, typename T>
 struct FindRangeAbsMaxFunctor {
   void operator()(const DeviceContext& ctx, const framework::Tensor& cur_scale,

@@ -150,8 +157,13 @@ class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
 };
 
 template <typename DeviceContext, typename T>
-class FakeQuantizeMovingAverageAbsMaxKernel : public framework::OpKernel<T> {
+class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
  public:
+  ~FakeMovingAverageAbsMaxKernelBase() {}
+  virtual void RunClipFunctor(const DeviceContext& dev_ctx,
+                              const framework::Tensor& in,
+                              const framework::Tensor& in_scale, int bin_cnt,
+                              framework::Tensor* out) const = 0;
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in = context.Input<framework::Tensor>("X");
     auto* in_scale = context.Input<framework::Tensor>("InScale");

@@ -165,8 +177,7 @@ class FakeQuantizeMovingAverageAbsMaxKernel : public framework::OpKernel<T> {
 
     // testing
     if (is_test) {
-      ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *in_scale,
-                                                  bin_cnt, out);
+      RunClipFunctor(dev_ctx, *in, *in_scale, bin_cnt, out);
       return;
     }
 

@@ -193,8 +204,31 @@ class FakeQuantizeMovingAverageAbsMaxKernel : public framework::OpKernel<T> {
         dev_ctx, *in_accum, *in_state, cur_scale_data, moving_rate, out_state,
         out_accum, out_scale);
 
-    ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
-                                                bin_cnt, out);
+    RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, out);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class FakeQuantizeMovingAverageAbsMaxKernel
+    : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
+ public:
+  void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in,
+                      const framework::Tensor& in_scale, int bin_cnt,
+                      framework::Tensor* out) const override {
+    ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, in, in_scale, bin_cnt,
+                                                out);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class FakeQuantizeDequantizeMovingAverageAbsMaxKernel
+    : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
+ public:
+  void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in,
+                      const framework::Tensor& in_scale, int bin_cnt,
+                      framework::Tensor* out) const override {
+    ClipAndFakeQuantDequantFunctor<DeviceContext, T>()(dev_ctx, in, in_scale,
+                                                       bin_cnt, out);
   }
 };
 
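Note: the header change is a template-method refactor. FakeMovingAverageAbsMaxKernelBase::Compute keeps the shared moving-average bookkeeping and delegates only the final clip/quantize step to the pure virtual RunClipFunctor, which the two concrete kernels override. The same shape in a short Python sketch (class and method names are illustrative analogues, not Paddle APIs):

    import numpy as np

    class MovingAverageKernelBase:
        def compute(self, x, scale, bin_cnt):
            # shared scale-update bookkeeping would live here
            return self.run_clip(x, scale, bin_cnt)

        def run_clip(self, x, scale, bin_cnt):
            raise NotImplementedError

    class QuantKernel(MovingAverageKernelBase):
        def run_clip(self, x, scale, bin_cnt):
            # quantize only: the output holds integer levels
            return np.round(np.clip(x, -scale, scale) * bin_cnt / scale)

    class QuantDequantKernel(MovingAverageKernelBase):
        def run_clip(self, x, scale, bin_cnt):
            # quantize, then dequantize back to the input's scale
            q = np.round(np.clip(x, -scale, scale) * bin_cnt / scale)
            return q * scale / bin_cnt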
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py

Lines changed: 57 additions & 40 deletions
@@ -90,46 +90,6 @@ def test_check_output(self):
         self.check_output()
 
 
-class TestFakeQuantizeMovingOp(OpTest):
-    def setUp(self):
-        self.op_type = "fake_quantize_moving_average_abs_max"
-        self.attrs = {
-            'bit_length': int(5),
-            'moving_rate': float(0.9),
-            'is_test': False
-        }
-        accum = np.zeros(1).astype("float32")
-        accum[0] = 1
-        state = np.zeros(1).astype("float32")
-        state[0] = 1
-        scale = np.zeros(1).astype("float32")
-        scale[0] = 0.001
-        self.inputs = {
-            'X': np.random.random((8, 16, 7, 7)).astype("float32"),
-            'InScale': scale,
-            'InAccum': accum,
-            'InState': state,
-        }
-
-        out_accum = np.zeros(1).astype("float32")
-        out_state = np.zeros(1).astype("float32")
-        out_scale = np.zeros(1).astype("float32")
-        out_accum[0] = self.attrs['moving_rate'] * accum[0] + np.max(
-            np.abs(self.inputs['X'])).astype("float32")
-        out_state[0] = self.attrs['moving_rate'] * state[0] + 1
-        out_scale = out_accum / out_state
-        self.outputs = {
-            'Out': np.round(self.inputs['X'] / out_scale * (
-                (1 << (self.attrs['bit_length'] - 1)) - 1)),
-            'OutAccum': out_accum,
-            'OutState': out_state,
-            'OutScale': out_scale,
-        }
-
-    def test_check_output(self):
-        self.check_output()
-
-
 class TestMovingAverageAbsMaxScaleOp(OpTest):
     def setUp(self):
         self.op_type = "moving_average_abs_max_scale"

@@ -193,5 +153,62 @@ def test_check_output(self):
         self.check_output(no_check_set=set(['OutScale', 'OutScales']))
 
 
+class TestMovingOpBase(OpTest):
+    def setUp(self):
+        self.init_type()
+        self.attrs = {
+            'bit_length': int(5),
+            'moving_rate': float(0.9),
+            'is_test': False
+        }
+        accum = np.zeros(1).astype("float32")
+        accum[0] = 1
+        state = np.zeros(1).astype("float32")
+        state[0] = 1
+        scale = np.zeros(1).astype("float32")
+        scale[0] = 0.001
+        self.inputs = {
+            'X': np.random.random((8, 16, 7, 7)).astype("float32"),
+            'InScale': scale,
+            'InAccum': accum,
+            'InState': state,
+        }
+
+        out_accum = np.zeros(1).astype("float32")
+        out_state = np.zeros(1).astype("float32")
+        out_scale = np.zeros(1).astype("float32")
+        out_accum[0] = self.attrs['moving_rate'] * accum[0] + np.max(
+            np.abs(self.inputs['X'])).astype("float32")
+        out_state[0] = self.attrs['moving_rate'] * state[0] + 1
+        out_scale = out_accum / out_state
+        out_data = self.calc_output(out_scale)
+        self.outputs = {
+            'Out': out_data,
+            'OutAccum': out_accum,
+            'OutState': out_state,
+            'OutScale': out_scale,
+        }
+
+    def init_type(self):
+        self.op_type = "fake_quantize_moving_average_abs_max"
+
+    def calc_output(self, out_scale):
+        return np.round(self.inputs['X'] / out_scale * (
+            (1 << (self.attrs['bit_length'] - 1)) - 1))
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestFakeQuantDequantMovingOp(TestMovingOpBase):
+    def init_type(self):
+        self.op_type = "fake_quantize_dequantize_moving_average_abs_max"
+
+    def calc_output(self, out_scale):
+        range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
+        return np.round(self.inputs['X'] / out_scale *
+                        range_v) * out_scale / range_v
+
+
 if __name__ == "__main__":
     unittest.main()
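Note: both tests share the moving-average scale update that the kernels implement: the accumulator decays toward a running sum of max-abs values, the state decays toward a running count, and their ratio is the reported scale. In NumPy terms (function name illustrative):

    import numpy as np

    def moving_average_scale(x, accum, state, moving_rate=0.9):
        # accum: decayed sum of max-abs values; state: decayed sample count
        accum = moving_rate * accum + np.max(np.abs(x))
        state = moving_rate * state + 1
        return accum, state, accum / state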
