Skip to content

Commit 43c7be5

Browse files
author
channings
authored
Merge pull request #178 from Channingss/fix_adative_pool
when pool is adative and ksize is [1,1], use gobal_pool to map adative_pool.
2 parents 322af23 + fa75b0d commit 43c7be5

File tree

1 file changed

+5
-11
lines changed
  • paddle2onnx/op_mapper

1 file changed

+5
-11
lines changed

paddle2onnx/op_mapper/nn.py

+5-11
Original file line numberDiff line numberDiff line change
@@ -94,30 +94,24 @@ def is_same_span(cls, in_size, out_size):
9494

9595
@classmethod
9696
def opset_1(cls, graph, node, **kw):
97-
if node.attr('global_pooling'):
97+
if node.attr('global_pooling') or (node.attr('adaptive') and
98+
node.attr('ksize') == [1, 1]):
9899
onnx_node = graph.make_node(
99100
cls.pool_type[node.attr('pooling_type')][1],
100101
inputs=node.input('X'),
101102
outputs=node.output('Out'))
102103
elif node.attr('adaptive'):
104+
# if pool is adaptive, check if input shape of pool is fixed.
103105
mapper_helper.is_static_shape(node.input_shape('X', 0))
104-
105106
input_h, input_w = node.input_shape('X', 0)[2:]
106107
output_h, output_w = node.output_shape('Out', 0)[2:]
107108
stride_h = int(input_h / output_h)
108109
stride_w = int(input_w / output_w)
110+
109111
kernel_h = input_h - (output_h - 1) * stride_h
110112
kernel_w = input_w - (output_w - 1) * stride_w
111113

112-
if node.attr('strides') is not None and (
113-
-1 not in node.attr('strides')):
114-
stride_h, stride_w = node.attr('strides')
115-
116-
if node.attr('ksize') is not None and (
117-
-1 not in node.attr('ksize')):
118-
kernel_h = input_h - (node.attr('ksize')[0] - 1) * stride_h
119-
kernel_w = input_w - (node.attr('ksize')[1] - 1) * stride_w
120-
114+
#check if kernel_size is fixed.
121115
if not cls.is_same_span(input_h, output_h) or not cls.is_same_span(
122116
input_w, output_w):
123117
raise Exception(

0 commit comments

Comments
 (0)