Skip to content

Commit 585e852

Browse files
【Paddle-TensorRT】fix pd_op.fused_conv2d_add_act (#72157)
* support fp32 * fix pd_op.fused_conv2d * simplified code --------- Co-authored-by: YuanRisheng <yuanrisheng@baidu.com>
1 parent 1befb85 commit 585e852

File tree

2 files changed

+34
-9
lines changed

2 files changed

+34
-9
lines changed

python/paddle/tensorrt/converter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,9 @@ def convert_program_to_trt(self):
652652
tensor_data, dtype=out_dtype
653653
).tolist()
654654

655+
if isinstance(constant_array, (int, float)):
656+
constant_array = [constant_array]
657+
655658
# convert builtin.constant to pd_op.full_int_array/full and then delete it
656659
with paddle.pir.core.program_guard(self.program):
657660
paddle.base.libpaddle.pir.reset_insertion_point_to_start()

python/paddle/tensorrt/converter_utils.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -669,16 +669,38 @@ def convert_conv2d(network, paddle_op, inputs):
669669
if paddle_op.name() == "pd_op.fused_conv2d_add_act":
670670
constant_manager = TensorRTConstantManager()
671671
bias_source_op = paddle_op.operands()[2].source().get_defining_op()
672-
if bias_source_op.name() == "builtin.parameter":
673-
bias_name = bias_source_op.attrs()['parameter_name']
674-
elif bias_source_op.name() == "builtin.constant":
675-
bias_np = bias_source_op.attrs()['value']
672+
673+
def get_bias_weights(current_op):
674+
if current_op.name() == "builtin.parameter":
675+
bias_name = current_op.attrs()["parameter_name"]
676+
elif current_op.name() == "builtin.constant":
677+
bias_name = current_op.attrs()["value"]
678+
else:
679+
raise ValueError(
680+
f"Unsupported bias source operation: {current_op.name()}"
681+
)
682+
683+
bias_np = constant_manager.get_constant_value(bias_name)
684+
return trt.Weights(bias_np)
685+
686+
if bias_source_op.name() in ["builtin.parameter", "builtin.constant"]:
687+
bias_weights = get_bias_weights(bias_source_op)
676688
else:
677-
raise ValueError(
678-
f"Unsupported bias source op: {bias_source_op.name()}"
679-
)
680-
bias_np = constant_manager.get_constant_value(bias_name)
681-
bias_weights = trt.Weights(bias_np)
689+
while bias_source_op.name() == "pd_op.reshape":
690+
bias_source_op = (
691+
bias_source_op.operands()[0].source().get_defining_op()
692+
)
693+
if bias_source_op.name() in [
694+
"builtin.parameter",
695+
"builtin.constant",
696+
]:
697+
bias_weights = get_bias_weights(bias_source_op)
698+
break
699+
else:
700+
raise ValueError(
701+
f"Unsupported bias source operation: {bias_source_op.name()}"
702+
)
703+
682704
layer = network.add_convolution_nd(
683705
input=input_tensor,
684706
num_output_maps=n_output,

0 commit comments

Comments
 (0)