Skip to content

Commit 4e2f4bc

Browse files
authored
Fix quantize model export for TRT backend (#830)
* fix bug in trt
* update code
* update code
* update code
* remove useless code
1 parent 3ed3454 commit 4e2f4bc

File tree

2 files changed

+28
-6
lines changed

2 files changed

+28
-6
lines changed

paddle2onnx/mapper/quantize/quantize_linear.cc

100644 → 100755
File mode changed.

paddle2onnx/mapper/quantize_helper.cc

Lines changed: 28 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -222,16 +222,34 @@ void QuantizeModelProcessor::AddTrtQDQ() {
222222
}
223223
quantize_tensors = tensor_names;
224224
}
225-
// An OP requires a separate quantize op
225+
226+
std::string negative_scale_tensor = "";
226227
for (std::string& name : quantize_tensors) {
227-
if (IsGraphOutput(name)) {
228-
continue;
229-
}
230228
Assert(
231229
helper_->quantize_info.find(name) != helper_->quantize_info.end(),
232230
"[QuantizeModelProcessor] Can not find quantize info for tensor: " +
233231
name);
234232
QuantizeInfo quantize_info = helper_->quantize_info[name];
233+
std::vector<float> scales = quantize_info.scale_;
234+
for (auto& i : scales) {
235+
if (i <= 1e-10) {
236+
negative_scale_tensor = negative_scale_tensor + " " + name;
237+
}
238+
}
239+
}
240+
if (negative_scale_tensor.size() > 0) {
241+
P2OLogger()
242+
<< "[Warning] The scale of tensors: [ " + negative_scale_tensor +
243+
" ] contains negative scale, so this OP will not be quantized."
244+
<< std::endl;
245+
continue;
246+
}
247+
// An OP requires a separate quantize op
248+
for (std::string& name : quantize_tensors) {
249+
if (IsGraphOutput(name)) {
250+
continue;
251+
}
252+
QuantizeInfo quantize_info = helper_->quantize_info[name];
235253
std::string scale_node = quantize_info.scale_node_;
236254
std::string zeros_node = quantize_info.zeros_node_;
237255
int64_t quantize_axis = quantize_info.quantize_axis_;
@@ -245,7 +263,11 @@ void QuantizeModelProcessor::AddTrtQDQ() {
245263
if (helper_->GetOpsetVersion() >= 13) {
246264
AddAttribute(dq_node, "axis", quantize_axis);
247265
}
248-
ReplaceInputOfAllNodes(name, dq_node->output(0));
266+
for (size_t i = 0; i < node->input_size(); ++i) {
267+
if (node->input(i) == name) {
268+
node->set_input(i, dq_node->output(0));
269+
}
270+
}
249271
}
250272
}
251273
}
@@ -984,4 +1006,4 @@ void QuantizeModelProcessor::AppendQuantizeTensor(const std::string& tensor,
9841006
}
9851007
}
9861008
}
987-
}
1009+
} // namespace paddle2onnx

0 commit comments

Comments
 (0)