@@ -137,35 +137,32 @@ def compute_loss(self, x, y, y_pred, sample_weight=None, **kwargs):
137
137
138
138
# TODO(ianstenbit): loss heatmap threshold should be configurable.
139
139
box_regression_mask = (
140
- ops .squeeze (
141
- ops .take (
142
- ops .reshape (heatmap , (ops .shape (heatmap )[0 ], - 1 )),
143
- index [..., 0 ] * ops .shape (heatmap )[1 ] + index [..., 1 ],
144
- axis = 1 ,
145
- ),
146
- axis = 0 ,
140
+ ops .take_along_axis (
141
+ ops .reshape (heatmap , (heatmap .shape [0 ], - 1 )),
142
+ index [..., 0 ] * heatmap .shape [1 ] + index [..., 1 ],
143
+ axis = 1 ,
147
144
)
148
145
> 0.95
149
146
)
150
147
151
- box = ops .squeeze (
152
- ops .take (
153
- ops .reshape (box , (ops .shape (box )[0 ], - 1 , 7 )),
154
- index [..., 0 ] * ops .shape (box )[1 ] + index [..., 1 ],
155
- axis = 1 ,
148
+ box = ops .take_along_axis (
149
+ ops .reshape (box , (ops .shape (box )[0 ], - 1 , 7 )),
150
+ ops .expand_dims (
151
+ index [..., 0 ] * ops .shape (box )[1 ] + index [..., 1 ], axis = - 1
156
152
),
157
- axis = 0 ,
153
+ axis = 1 ,
158
154
)
159
- box_pred = ops .squeeze (
160
- ops .take (
161
- ops .reshape (
162
- box_pred ,
163
- (ops .shape (box_pred )[0 ], - 1 , ops .shape (box_pred )[- 1 ]),
164
- ),
155
+
156
+ box_pred = ops .take_along_axis (
157
+ ops .reshape (
158
+ box_pred ,
159
+ (ops .shape (box_pred )[0 ], - 1 , ops .shape (box_pred )[- 1 ]),
160
+ ),
161
+ ops .expand_dims (
165
162
index [..., 0 ] * ops .shape (box_pred )[1 ] + index [..., 1 ],
166
- axis = 1 ,
163
+ axis = - 1 ,
167
164
),
168
- axis = 0 ,
165
+ axis = 1 ,
169
166
)
170
167
171
168
box_center_mask = heatmap > 0.99
0 commit comments