@@ -94,30 +94,24 @@ def is_same_span(cls, in_size, out_size):
94
94
95
95
@classmethod
96
96
def opset_1 (cls , graph , node , ** kw ):
97
- if node .attr ('global_pooling' ):
97
+ if node .attr ('global_pooling' ) or (node .attr ('adaptive' ) and
98
+ node .attr ('ksize' ) == [1 , 1 ]):
98
99
onnx_node = graph .make_node (
99
100
cls .pool_type [node .attr ('pooling_type' )][1 ],
100
101
inputs = node .input ('X' ),
101
102
outputs = node .output ('Out' ))
102
103
elif node .attr ('adaptive' ):
104
+ # if pool is adaptive, check if input shape of pool is fixed.
103
105
mapper_helper .is_static_shape (node .input_shape ('X' , 0 ))
104
-
105
106
input_h , input_w = node .input_shape ('X' , 0 )[2 :]
106
107
output_h , output_w = node .output_shape ('Out' , 0 )[2 :]
107
108
stride_h = int (input_h / output_h )
108
109
stride_w = int (input_w / output_w )
110
+
109
111
kernel_h = input_h - (output_h - 1 ) * stride_h
110
112
kernel_w = input_w - (output_w - 1 ) * stride_w
111
113
112
- if node .attr ('strides' ) is not None and (
113
- - 1 not in node .attr ('strides' )):
114
- stride_h , stride_w = node .attr ('strides' )
115
-
116
- if node .attr ('ksize' ) is not None and (
117
- - 1 not in node .attr ('ksize' )):
118
- kernel_h = input_h - (node .attr ('ksize' )[0 ] - 1 ) * stride_h
119
- kernel_w = input_w - (node .attr ('ksize' )[1 ] - 1 ) * stride_w
120
-
114
+ #check if kernel_size is fixed.
121
115
if not cls .is_same_span (input_h , output_h ) or not cls .is_same_span (
122
116
input_w , output_w ):
123
117
raise Exception (
0 commit comments