Skip to content

Commit a177d48

Browse files
xiaolil1luotao1
authored andcommitted
Add Requantize OP (#15318)
* Enable INT8 ReQuantize OP test=develop * Clean code test=develop * Add comments test=develop * Revert "Clean code" test=develop This reverts commit a7a49b8. * Modify requantize op test test=develop * fix requantize UT by moving public function to public test file. test=develop * Fix test fail due to file address change. test=develop * Change file address for requantize op. test=develop
1 parent f5a3751 commit a177d48

File tree

6 files changed

+295
-14
lines changed

6 files changed

+295
-14
lines changed
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "mkldnn.hpp"
16+
#include "paddle/fluid/framework/data_layout_transform.h"
17+
#include "paddle/fluid/framework/tensor.h"
18+
#include "paddle/fluid/operators/requantize_op.h"
19+
#include "paddle/fluid/platform/mkldnn_helper.h"
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
using mkldnn::memory;
25+
using mkldnn::primitive;
26+
using mkldnn::reorder;
27+
using platform::to_void_cast;
28+
using Tensor = framework::Tensor;
29+
using framework::DataLayout;
30+
using mkldnn::stream;
31+
using platform::GetMKLDNNFormat;
32+
33+
template <typename T>
34+
class ReQuantOpKernel : public framework::OpKernel<T> {
35+
public:
36+
void Compute(const framework::ExecutionContext& ctx) const override {
37+
auto* input = ctx.Input<Tensor>("Input");
38+
auto scale_in = ctx.Attr<float>("Scale_in");
39+
auto scale_out = ctx.Attr<float>("Scale_out");
40+
auto* output = ctx.Output<Tensor>("Output");
41+
auto& dev_ctx =
42+
ctx.template device_context<platform::MKLDNNDeviceContext>();
43+
const auto& engine = dev_ctx.GetEngine();
44+
45+
std::vector<primitive> pipeline;
46+
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
47+
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
48+
mkldnn::memory::data_type src_dt =
49+
paddle::framework::ToMKLDNNDataType(input->type());
50+
mkldnn::memory::data_type dst_dt = src_dt; // TODO(Xiaoli) support
51+
// requantize from different
52+
// data type (e.g., s8 to u8)
53+
mkldnn::memory::format src_fmt = memory::format::nhwc;
54+
mkldnn::memory::format dst_fmt = memory::format::nhwc;
55+
56+
const T* input_data = input->data<T>();
57+
T* output_data = output->mutable_data<T>(ctx.GetPlace());
58+
float scale_shift = scale_out / scale_in;
59+
60+
mkldnn::primitive_attr attri;
61+
int mask = 0;
62+
attri.set_output_scales(mask, {scale_shift});
63+
64+
auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt);
65+
auto src_pd = mkldnn::memory::primitive_desc(src_md, engine);
66+
auto src_memory =
67+
std::make_shared<mkldnn::memory>(src_pd, to_void_cast<T>(input_data));
68+
std::shared_ptr<primitive::at> src_memory_p =
69+
std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
70+
71+
auto dst_md = platform::MKLDNNMemDesc({dst_tz}, dst_dt, dst_fmt);
72+
auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
73+
auto dst_memory = mkldnn::memory(dst_pd, to_void_cast<T>(output_data));
74+
75+
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
76+
new reorder::primitive_desc(src_pd, dst_pd, attri));
77+
78+
auto reorder_p = std::shared_ptr<reorder>(
79+
new reorder(*reorder_pd, *src_memory_p, dst_memory));
80+
pipeline.push_back(*reorder_p);
81+
stream(stream::kind::eager).submit(pipeline).wait();
82+
83+
output->set_layout(DataLayout::kMKLDNN);
84+
output->set_format(GetMKLDNNFormat(dst_memory));
85+
}
86+
};
87+
88+
} // namespace operators
89+
} // namespace paddle
90+
91+
namespace ops = paddle::operators;
92+
93+
REGISTER_OP_KERNEL(requantize, MKLDNN, ::paddle::platform::CPUPlace,
94+
ops::ReQuantOpKernel<int8_t>, ops::ReQuantOpKernel<uint8_t>);
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License. */
14+
15+
#include "paddle/fluid/operators/requantize_op.h"
16+
#ifdef PADDLE_WITH_MKLDNN
17+
#include "paddle/fluid/platform/mkldnn_helper.h"
18+
#endif
19+
20+
namespace paddle {
21+
namespace operators {
22+
23+
framework::OpKernelType ReQuantOp::GetExpectedKernelType(
24+
const framework::ExecutionContext& ctx) const {
25+
framework::LibraryType library_ = framework::LibraryType::kMKLDNN;
26+
framework::DataLayout layout_ = framework::DataLayout::kMKLDNN;
27+
28+
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
29+
ctx.GetPlace(), layout_, library_);
30+
}
31+
32+
void ReQuantOpMaker::Make() {
33+
AddInput("Input", "input data");
34+
AddOutput("Output", "output data");
35+
AddAttr<float>("Scale_in", "scale in data").SetDefault({1.0f});
36+
AddAttr<float>("Scale_out", "scale out data").SetDefault({1.0f});
37+
AddComment(
38+
R"DOC(This op will re-quantize data from INT8 with scale_in to INT8 with scale_out)DOC");
39+
}
40+
41+
} // namespace operators
42+
} // namespace paddle
43+
namespace ops = paddle::operators;
44+
45+
REGISTER_OPERATOR(requantize, ops::ReQuantOp, ops::ReQuantOpMaker,
46+
paddle::framework::DefaultGradOpDescMaker<true>);
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <string>
18+
#include <vector>
19+
#include "paddle/fluid/framework/op_registry.h"
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
using framework::OpKernelType;
25+
using framework::Tensor;
26+
27+
class ReQuantOp : public framework::OperatorWithKernel {
28+
public:
29+
using framework::OperatorWithKernel::OperatorWithKernel;
30+
31+
void InferShape(framework::InferShapeContext* ctx) const override {
32+
ctx->SetOutputDim("Output", ctx->GetInputDim("Input"));
33+
ctx->ShareLoD("Input", /*->*/ "Output");
34+
}
35+
36+
protected:
37+
framework::OpKernelType GetExpectedKernelType(
38+
const framework::ExecutionContext& ctx) const override;
39+
};
40+
41+
class ReQuantOpMaker : public framework::OpProtoAndCheckerMaker {
42+
public:
43+
void Make() override;
44+
};
45+
46+
} // namespace operators
47+
} // namespace paddle

python/paddle/fluid/tests/unittests/mkldnn/mkldnn_op_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,17 @@ def __assert_close(tensor, np_array, msg, atol=1e-4):
7070
fetch_list=['x@GRAD', 'out'])
7171

7272
__assert_close(x_grad, out[0], 'x@GRAD')
73+
74+
75+
def format_reorder(out, size):
76+
in_n = size[0]
77+
out_h = size[2]
78+
out_w = size[3]
79+
out_c = size[1]
80+
out_tmp = np.zeros((in_n, out_h, out_w, out_c))
81+
for n in range(in_n):
82+
for i in range(out_h):
83+
for j in range(out_w):
84+
for m in range(out_c):
85+
out_tmp[n, i, j, m] = out[n, m, i, j]
86+
return out_tmp.reshape(in_n, out_c, out_h, out_w)

python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import paddle.fluid.core as core
2121
from paddle.fluid.tests.unittests.op_test import OpTest
2222
from paddle.fluid.tests.unittests.test_conv2d_op import conv2d_forward_naive, TestConv2dOp
23+
from mkldnn_op_test import format_reorder
2324

2425

2526
def conv2d_forward_refer(input, filter, group, conv_param):
@@ -29,20 +30,6 @@ def conv2d_forward_refer(input, filter, group, conv_param):
2930
return format_reorder(out, size)
3031

3132

32-
def format_reorder(out, size):
33-
in_n = size[0]
34-
out_h = size[2]
35-
out_w = size[3]
36-
out_c = size[1]
37-
out_tmp = np.zeros((in_n, out_h, out_w, out_c))
38-
for n in range(in_n):
39-
for i in range(out_h):
40-
for j in range(out_w):
41-
for m in range(out_c):
42-
out_tmp[n, i, j, m] = out[n, m, i, j]
43-
return out_tmp.reshape(in_n, out_c, out_h, out_w)
44-
45-
4633
class TestConv2dInt8Op(TestConv2dOp):
4734
def setUp(self):
4835
self.op_type = "conv2d"
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import unittest
18+
import numpy as np
19+
from paddle.fluid.tests.unittests.op_test import OpTest
20+
from mkldnn_op_test import format_reorder
21+
22+
23+
class TestReQuantizeOp(OpTest):
24+
def setUp(self):
25+
self.op_type = 'requantize'
26+
self.scale_in = 2.0
27+
self.scale_out = 1.5
28+
self.input_size = [1, 1, 5, 5]
29+
self.data_type = 'int8'
30+
self.set_scale()
31+
self.set_data_type()
32+
33+
scale_shift = self.scale_out / self.scale_in
34+
35+
if self.data_type == 'int8':
36+
input = (np.random.randint(0, 100, self.input_size) - 50
37+
).astype(self.data_type)
38+
output_tmp = np.round(input.astype('float32') *
39+
scale_shift).astype('int8')
40+
else:
41+
input = (np.random.randint(0, 100,
42+
self.input_size)).astype(self.data_type)
43+
output_tmp = np.round(input.astype('float32') *
44+
scale_shift).astype('uint8')
45+
46+
output = format_reorder(output_tmp, self.input_size)
47+
48+
self.inputs = {'Input': OpTest.np_dtype_to_fluid_dtype(input)}
49+
50+
self.outputs = {'Output': output}
51+
52+
self.attrs = {'Scale_in': self.scale_in, 'Scale_out': self.scale_out}
53+
54+
def test_check_output(self):
55+
self.check_output()
56+
57+
def set_scale(self):
58+
pass
59+
60+
def set_data_type(OpTest):
61+
pass
62+
63+
64+
#--------------------test requantize with s8 input--------------------
65+
66+
67+
class TestReQuantizeOp1(TestReQuantizeOp):
68+
def set_scale(self):
69+
self.scale_in = 1.5
70+
self.scale_out = 1.5
71+
72+
73+
class TestReQuantizeOp2(TestReQuantizeOp):
74+
def set_scale(self):
75+
self.scale_in = 0.1
76+
self.scale_out = 0.2
77+
78+
79+
#--------------------test requantize with u8 input--------------------
80+
81+
82+
class TestReQuantizeOp3(TestReQuantizeOp1):
83+
def set_data_type(self):
84+
self.data_type = 'uint8'
85+
86+
87+
class TestReQuantizeOp4(TestReQuantizeOp2):
88+
def set_data_type(self):
89+
self.data_type = 'uint8'
90+
91+
92+
if __name__ == '__main__':
93+
unittest.main()

0 commit comments

Comments
 (0)